diff --git a/R/regularized_regression.R b/R/regularized_regression.R index 2d21335a..2557115e 100644 --- a/R/regularized_regression.R +++ b/R/regularized_regression.R @@ -370,6 +370,56 @@ susie_weights <- function(X = NULL, y = NULL, susie_fit = NULL, ...) { } } +#' @importFrom susieR coef.susie +#' @export +susie_ash_weights <- function(X = NULL, y = NULL, susie_ash_fit = NULL, ...) { + if (is.null(susie_ash_fit)) { + # get susie_ash_fit object + susie_ash_fit <- susie_wrapper(X, y, unmappable_effects = "ash", standardize = FALSE, intercept = FALSE, ...) + } + if (!is.null(X)) { + if (length(susie_ash_fit$pip) != ncol(X)) { + stop(paste0( + "Dimension mismatch on number of variant in susie_ash_fit ", length(susie_ash_fit$pip), + " and TWAS weights ", ncol(X), ". " + )) + } + } + if ("alpha" %in% names(susie_ash_fit) && "mu" %in% names(susie_ash_fit) && "X_column_scale_factors" %in% names(susie_ash_fit)) { + # This is designed to cope with output from pecotmr::susie_post_processor() + # We set intercept to 0 and later trim it off anyways + susie_ash_fit$intercept <- 0 + return(coef.susie(susie_ash_fit)[-1]) + } else { + return(rep(0, length(susie_ash_fit$pip))) + } +} + +#' @importFrom susieR coef.susie +#' @export +susie_inf_weights <- function(X = NULL, y = NULL, susie_inf_fit = NULL, ...) { + if (is.null(susie_inf_fit)) { + # get susie_inf_fit object + susie_inf_fit <- susie_wrapper(X, y, unmappable_effects = "inf", standardize = FALSE, intercept = FALSE, ...) + } + if (!is.null(X)) { + if (length(susie_inf_fit$pip) != ncol(X)) { + stop(paste0( + "Dimension mismatch on number of variant in susie_inf_fit ", length(susie_inf_fit$pip), + " and TWAS weights ", ncol(X), ". " + )) + } + } + if ("alpha" %in% names(susie_inf_fit) && "mu" %in% names(susie_inf_fit) && "X_column_scale_factors" %in% names(susie_inf_fit)) { + # This is designed to cope with output from pecotmr::susie_post_processor() + # We set intercept to 0 and later trim it off anyways + susie_inf_fit$intercept <- 0 + return(coef.susie(susie_inf_fit)[-1]) + } else { + return(rep(0, length(susie_inf_fit$pip))) + } +} + #' @importFrom mr.mash.alpha coef.mr.mash #' @export mrmash_weights <- function(mrmash_fit = NULL, X = NULL, Y = NULL, ...) { @@ -829,43 +879,3 @@ bayes_c_rss_weights <- function(sumstats, LD, ...) { bayes_r_rss_weights <- function(sumstats, LD, ...) { return(bayes_alphabet_rss_weights(sumstats, LD, method = "bayesR", ...)) } - -#' @export -susie_ash_weights <- function(susie_ash_fit, X = NULL, y = NULL, ...) { - # If the fit object is missing or NULL, try to recover it from the parent frame. - if (missing(susie_ash_fit) || is.null(susie_ash_fit)) { - susie_ash_fit <- get0("susie_ash_fit", envir = parent.frame()) - if (is.null(susie_ash_fit)) { - stop("A susie_ash_fit object is required.") - } - } - if (!is.null(X)) { - if (length(susie_ash_fit$marginal_PIP) != ncol(X)) { - stop(paste0("Dimension mismatch on number of variants in susie_ash_fit ", - length(susie_ash_fit$marginal_PIP), " and TWAS weights ", ncol(X), ". ")) - } - } - # Calculate coefficients as per the provided formula. - weights <- rowSums(susie_ash_fit$mu * susie_ash_fit$PIP) + susie_ash_fit$theta - return(weights) -} - -#' @export -susie_inf_weights <- function(susie_inf_fit, X = NULL, y = NULL, ...) { - # If the fit object is missing or NULL, try to recover it from the parent frame. - if (missing(susie_inf_fit) || is.null(susie_inf_fit)) { - susie_inf_fit <- get0("susie_inf_fit", envir = parent.frame()) - if (is.null(susie_inf_fit)) { - stop("A susie_inf_fit object is required.") - } - } - if (!is.null(X)) { - if (length(susie_inf_fit$marginal_PIP) != ncol(X)) { - stop(paste0("Dimension mismatch on number of variants in susie_inf_fit ", - length(susie_inf_fit$marginal_PIP), " and TWAS weights ", ncol(X), ". ")) - } - } - # Calculate coefficients as per the provided formula. - weights <- rowSums(susie_inf_fit$mu * susie_inf_fit$PIP) + susie_inf_fit$alpha - return(weights) -} diff --git a/R/twas_weights.R b/R/twas_weights.R index b1bfd190..73ddcc05 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -448,7 +448,9 @@ twas_weights_pipeline <- function(X, bayes_r_weights = list(), bayes_l_weights = list(), mrash_weights = list(init_prior_sd = TRUE, max.iter = 100), - susie_weights = list(refine = FALSE, init_L = 5, max_L = 20) + susie_weights = list(refine = FALSE, init_L = 5, max_L = 20), + susie_ash_weights = list(refine = FALSE, init_L = 5, max_L = 20), + susie_inf_weights = list(refine = FALSE, init_L = 5, max_L = 20) ), max_cv_variants = -1, cv_threads = 1, diff --git a/man/twas_weights_pipeline.Rd b/man/twas_weights_pipeline.Rd index f3e6e3ed..302af2f7 100644 --- a/man/twas_weights_pipeline.Rd +++ b/man/twas_weights_pipeline.Rd @@ -12,7 +12,9 @@ twas_weights_pipeline( sample_partition = NULL, weight_methods = list(enet_weights = list(), lasso_weights = list(), bayes_r_weights = list(), bayes_l_weights = list(), mrash_weights = list(init_prior_sd = TRUE, max.iter - = 100), susie_weights = list(refine = FALSE, init_L = 5, max_L = 20)), + = 100), susie_weights = list(refine = FALSE, init_L = 5, max_L = 20), + susie_ash_weights = list(refine = FALSE, init_L = 5, max_L = 20), susie_inf_weights = + list(refine = FALSE, init_L = 5, max_L = 20)), max_cv_variants = -1, cv_threads = 1, cv_weight_methods = NULL