From edcb7c1d423c498c013fb2226e25020876d88be2 Mon Sep 17 00:00:00 2001 From: Gao Wang Date: Tue, 7 Apr 2026 08:22:29 -0400 Subject: [PATCH] Fix PRS-CS sigma sampling bug and add signal recovery test The sigma (residual variance) sampling had an incorrect parameterization: Before: sigma = 1.0 / Gamma((n+p)/2, scale=1) / err = InvGamma(a,1)/err After: sigma = err / Gamma((n+p)/2, scale=1) = err * InvGamma(a,1) where a = (n+p)/2 is the Gamma shape (not the prs_cs_mcmc argument `a`). The original PRS-CS (Ge et al., getian107/PRScs) samples sigma as 1/Gamma(a, scale=1/err) = InvGamma(a, rate=err), which has mean err/(a-1). The buggy version gave mean 1/(err*(a-1)), off by a factor of err^2. This caused incorrect posterior variance estimation. Also add a signal recovery test for prs_cs with realistic binomial genotype data and verify sigma is in a reasonable range. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/prscs_mcmc.h | 6 +++++- tests/testthat/test_regularized_regression.R | 22 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/prscs_mcmc.h b/src/prscs_mcmc.h index 6e6ab00a..371e858b 100644 --- a/src/prscs_mcmc.h +++ b/src/prscs_mcmc.h @@ -259,7 +259,11 @@ std::map prs_cs_mcmc(double a, double b, double* phi, double err = std::max(n / 2.0 * (1.0 - 2.0 * arma::dot(beta, beta_mrg) + quad), n / 2.0 * arma::sum(arma::pow(beta, 2) / psi)); - sigma = 1.0 / gamma_dist_sigma(rng) / err; + // Original PRS-CS (Ge et al.): sigma = 1/Gamma((n+p)/2, scale=1/err) + // = InvGamma((n+p)/2, rate=err), i.e. mean = err/((n+p)/2-1). + // gamma_dist_sigma samples X ~ Gamma((n+p)/2, scale=1), so + // sigma = err/X gives the correct InvGamma(a, rate=err). 
+ sigma = err / gamma_dist_sigma(rng); arma::vec delta = arma::vec(p); for (int jj = 0; jj < p; ++jj) { diff --git a/tests/testthat/test_regularized_regression.R b/tests/testthat/test_regularized_regression.R index ed45211c..bba34c5d 100644 --- a/tests/testthat/test_regularized_regression.R +++ b/tests/testthat/test_regularized_regression.R @@ -102,6 +102,28 @@ test_that("prs_cs works without maf (maf = NULL)", { expect_equal(length(result$beta_est), p) }) +# ---- prs_cs signal recovery ---- +test_that("prs_cs recovers signal direction on simulated genotype data", { + set.seed(42) + n <- 500 + p <- 20 + X <- matrix(rbinom(n * p, 2, 0.3), nrow = n) + beta_true <- rep(0, p) + beta_true[c(3, 10, 15)] <- c(0.4, -0.3, 0.2) + y <- X %*% beta_true + rnorm(n) + bhat <- as.vector(cor(y, X)) + R <- cor(X) + result <- prs_cs(bhat = bhat, LD = list(blk1 = R), n = n, + n_iter = 1000, n_burnin = 500, thin = 5, seed = 42) + expect_true("beta_est" %in% names(result)) + expect_equal(length(result$beta_est), p) + expect_true(all(is.finite(result$beta_est))) + # Sigma should be reasonable (near 1 for standardized data) + expect_true(result$sigma_est > 0.1 && result$sigma_est < 10) + # Correlation with truth should be positive (signal recovery) + expect_gt(cor(result$beta_est, beta_true), 0.5) +}) + # ---- prs_cs_weights (wrapper) ---- test_that("prs_cs_weights calls prs_cs and returns beta_est", { set.seed(42)