diff --git a/DESCRIPTION b/DESCRIPTION index 1a011cd..8ef92e3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,11 +27,13 @@ Depends: mia Imports: dplyr, + glmnet, methods, S4Vectors, SingleCellExperiment, stats, SummarizedExperiment, + survival, tidyr, TreeSummarizedExperiment Suggests: diff --git a/NAMESPACE b/NAMESPACE index 3380ac6..7775cb9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,21 +5,25 @@ export(addBimodality) export(addShortTermChange) export(addStability) export(addStepwiseDivergence) +export(addSurvival) export(getBaselineDivergence) export(getBimodality) export(getShortTermChange) export(getStability) export(getStepwiseDivergence) +export(getSurvival) exportMethods(addBaselineDivergence) exportMethods(addBimodality) exportMethods(addShortTermChange) exportMethods(addStability) exportMethods(addStepwiseDivergence) +exportMethods(addSurvival) exportMethods(getBaselineDivergence) exportMethods(getBimodality) exportMethods(getShortTermChange) exportMethods(getStability) exportMethods(getStepwiseDivergence) +exportMethods(getSurvival) import(S4Vectors) import(SingleCellExperiment) import(SummarizedExperiment) @@ -44,9 +48,14 @@ importFrom(dplyr,summarize) importFrom(dplyr,sym) importFrom(dplyr,ungroup) importFrom(dplyr,vars) +importFrom(glmnet,Cindex) +importFrom(glmnet,cv.glmnet) +importFrom(glmnet,glmnet) importFrom(stats,coef) importFrom(stats,lm) importFrom(stats,median) importFrom(stats,setNames) +importFrom(survival,Surv) +importFrom(survival,coxph) importFrom(tidyr,pivot_wider) importFrom(tidyr,unnest) diff --git a/NEWS b/NEWS index ca47b03..519de29 100755 --- a/NEWS +++ b/NEWS @@ -1,3 +1,6 @@ +Changes in version 0.99.11 ++ Added wrapper for survival analysis + Changes in version 0.99.10 Date: 2025-09-06 + Added survival data on Crohn's disease (crohn_survival) diff --git a/R/AllGenerics.R b/R/AllGenerics.R index 40a420e..ca65f48 100644 --- a/R/AllGenerics.R +++ b/R/AllGenerics.R @@ -50,3 +50,13 @@ setGeneric("getStability", signature = "x", 
function(x, ...)
 #' @export
 setGeneric("addStability", signature = "x", function(x, ...)
     standardGeneric("addStability"))
+
+#' @rdname getSurvival
+#' @export
+setGeneric("getSurvival", signature = "x", function(x, ...)
+    standardGeneric("getSurvival"))
+
+#' @rdname getSurvival
+#' @export
+setGeneric("addSurvival", signature = "x", function(x, ...)
+    standardGeneric("addSurvival"))
\ No newline at end of file
diff --git a/R/getSurvival.R b/R/getSurvival.R
new file mode 100644
index 0000000..422a1c7
--- /dev/null
+++ b/R/getSurvival.R
@@ -0,0 +1,366 @@
+#' @name
+#' getSurvival
+#'
+#' @export
+#'
+#' @title
+#' Survival analysis
+#'
+#' @description
+#' Fit a (penalized) Cox proportional hazards model on microbiome data contained
+#' in a SummarizedExperiment object. Data transformations (e.g. pairwise
+#' log-ratios) should be handled upstream
+#' (e.g. with \code{mia::transformAssay()}).
+#'
+#' @param time.col \code{Character scalar}. Column name in \code{colData(x)}
+#' representing time to event or follow-up time. Must be numeric.
+#'
+#' @param event.col \code{Character scalar}. Column name in \code{colData(x)}
+#' representing event occurrence. Accepts numeric (\code{0}/\code{1}) or
+#' logical (\code{TRUE}/\code{FALSE}) values.
+#'
+#' @param col.var \code{Character vector}. Optional. Specifies covariate
+#' columns in \code{colData(x)} to adjust for in the survival model.
+#' (Default: \code{NULL})
+#'
+#' @param name \code{Character scalar}. Name under which \code{addSurvival}
+#' stores the results in the metadata of \code{x}.
+#' (Default: \code{"survival"})
+#'
+#' @param assay.type \code{Character scalar}. Specifies which assay values are
+#' used as predictors in the survival model. (Default: \code{"counts"})
+#'
+#' @param ... additional arguments.
+#' \itemize{
+#' \item \code{penalized}: \code{Logical}. If \code{TRUE}, fit penalized Cox
+#' regression using \code{glmnet}. If \code{FALSE}, fit standard Cox model
+#' using \code{survival::coxph}. (Default: \code{TRUE})
+#'
+#' \item \code{lambda}: \code{Character or numeric}. Penalty at which the
+#' \code{\link[glmnet]{cv.glmnet}} fit is evaluated: \code{"lambda.1se"},
+#' \code{"lambda.min"}, or a numeric value. (Default: \code{"lambda.1se"})
+#'
+#' \item \code{alpha}: \code{Numeric scalar}. Elastic net mixing parameter
+#' that controls the balance between Lasso and Ridge regression:
+#' \code{alpha = 1} corresponds to Lasso,
+#' \code{alpha = 0} corresponds to Ridge.
+#' Values between 0 and 1 specify a combination of the two.
+#' (Default: \code{0.9})
+#'
+#' \item \code{nfolds}: \code{Integer scalar}. Number of cross-validation
+#' folds for \code{cv.glmnet}. (Default: \code{10})
+#'
+#' \item \code{nvar}: \code{Integer scalar}. Optional. Maximum number of
+#' variables (log-ratios) to include in the model. (Default: \code{NULL})
+#'
+#' \item \code{coef.threshold}: \code{Numeric scalar}. Minimum absolute value
+#' for a coefficient to be included in the final model. (Default: \code{0})
+#' }
+#'
+#' @inheritParams addBaselineDivergence
+#'
+#' @return \code{getSurvival} returns a list with model summaries:
+#' \itemize{
+#' \item \code{coef}: model coefficients that pass \code{coef.threshold},
+#' rescaled so that their absolute values sum to 2
+#' \item \code{risk_scores}: predicted risk scores
+#' \item \code{c_index}: apparent concordance index
+#' \item \code{c_index_cv_mean}: mean cross-validated C-index (if penalized)
+#' \item \code{c_index_cv_sd}: SD of cross-validated C-index (if penalized)
+#' \item \code{fit}: fitted model object (coxph or cv.glmnet)
+#' }
+#' \code{addSurvival} returns the object with this list added to its metadata.
+#'
+#' @seealso \code{\link[coda4microbiome]{coda_coxnet}},
+#' \code{\link[glmnet]{cv.glmnet}}
+#'
+#' @references
+#' Meritxell Pujolassos, Antoni Susín, M. Luz Calle (2024).
+#' \emph{Microbiome compositional data analysis for survival studies}.
+#' NAR Genomics and Bioinformatics, 6(2), lqae038.
+#' \doi{10.1093/nargab/lqae038}
+#'
+#' @examples
+#' data(crohn_survival)
+#' tse <- crohn_survival
+#' tse <- transformAssay(tse, method = "relabundance")
+#' fit <- getSurvival(
+#'     tse, assay.type = "relabundance",
+#'     time.col = "event_time", event.col = "event"
+#' )
+#'
+NULL
+
+#' @rdname getSurvival
+#' @export
+setMethod("addSurvival", signature = c(x = "SummarizedExperiment"),
+    function(x, time.col, event.col, name = "survival", ...){
+        .check_input(name, "character scalar")
+        x <- .check_and_get_altExp(x, ...)
+        # Run analysis ('altexp' is dropped because x was already resolved)
+        args <- c(
+            list(x = x, time.col = time.col, event.col = event.col,
+                name = name),
+            list(...)[!names(list(...)) %in% c("altexp")])
+        res <- do.call(getSurvival, args)
+        # Add results to metadata
+        x <- .add_values_to_metadata(x, name, res, ...)
+        return(x)
+    }
+)
+
+#' @rdname getSurvival
+#' @export
+setMethod("getSurvival", signature(x = "SummarizedExperiment"),
+    function(x, time.col, event.col, assay.type = "counts", col.var = NULL,
+            ...){
+        # Input checks
+        x <- .check_data_for_survival(
+            x, time.col, event.col, col.var, assay.type, ...)
+        # Extract data
+        args <- .get_data_for_survival(
+            x, time.col, event.col, col.var, assay.type)
+        args <- c(args, list(...))
+        # Fit survival model
+        res <- do.call(.calc_survival, args)
+        return(res)
+    }
+)
+
+############################# Internal helpers #################################
+
+# Check input validity for survival analysis
+#
+# Ensures that the required columns for survival analysis are present in
+# colData, are of the correct type, and that the requested assay exists.
+.check_data_for_survival <- function(
+        x, time.col, event.col, col.var, assay.type, ...){
+    # Ensure we are working with the correct alternative experiment
+    x <- .check_and_get_altExp(x, ...)
+    # Check that 'time.col' exists in colData and is a character scalar
+    .check_input(time.col, list("character scalar"), colnames(colData(x)))
+    # Check that 'event.col' exists in colData and is a character scalar
+    .check_input(event.col, list("character scalar"), colnames(colData(x)))
+    # Check that the requested assay is present in the object
+    .check_assay_present(assay.type, x)
+    # Verify that the time column contains numeric values
+    if( !is.numeric(x[[time.col]]) ){
+        stop("'time.col' must be numeric.", call. = FALSE)
+    }
+    # Verify that the status column is either logical or numeric (0/1)
+    if( !is.logical(x[[event.col]]) && !is.numeric(x[[event.col]]) ){
+        stop("'event.col' must be numeric (0/1) or logical.", call. = FALSE)
+    }
+    # If covariates are provided, check that they exist in colData
+    if( !is.null(col.var) ){
+        .check_input(col.var, list("character vector"), colnames(colData(x)))
+    }
+    return(x)
+}
+
+# Extract and prepare data for survival analysis
+#
+# Retrieves the assay matrix, survival time, survival status,
+# and optional covariates from the input object,
+# formatted for downstream survival models.
+.get_data_for_survival <- function(
+        x, time.col, event.col, col.var, assay.type){
+    # Extract assay data (samples as rows, features as columns)
+    mat <- assay(x, assay.type) |> t()
+    # Extract survival time column and ensure numeric
+    time <- x[[time.col]] |> as.numeric()
+    # Extract survival status column and ensure numeric (0/1)
+    status <- x[[event.col]] |> as.numeric()
+    # Extract optional covariates from colData
+    if( !is.null(col.var) ){
+        col.var <- colData(x)[, col.var, drop = FALSE] |> as.data.frame()
+    }
+    # Return prepared components in a list
+    res <- list(
+        mat = mat,
+        time = time,
+        status = status,
+        col.var = col.var
+    )
+    return(res)
+}
+
+
+# Dispatcher to fit a survival model
+#
+# Chooses between a penalized Cox model (via glmnet) and a standard Cox
+# proportional hazards model, based on the `penalized` argument.
+.calc_survival <- function(mat, time, status, col.var, penalized = TRUE, ...){
+    if( !.is_a_bool(penalized) ){
+        stop("'penalized' must be TRUE or FALSE.", call. = FALSE)
+    }
+    FUN <- if( penalized ) .fit_penalized_cox else .fit_standard_cox
+    res <- FUN(mat = mat, time = time, status = status, col.var = col.var, ...)
+    return(res)
+}
+
+# Fit a penalized Cox proportional hazards model using glmnet
+#
+# Performs cross-validated elastic net penalized Cox regression. Optional
+# covariates enter the model through an offset computed from a
+# covariate-only Cox model.
+#' @importFrom survival Surv coxph
+#' @importFrom glmnet cv.glmnet glmnet
+.fit_penalized_cox <- function(
+        mat, time, status, col.var, alpha = 0.9, nfolds = 10,
+        lambda = "lambda.1se", ...){
+    # Create survival response object as glmnet requires a Surv object for Cox
+    # regression
+    y <- Surv(time, status)
+
+    # Fit Cox model for covariates (if provided) and use as offset. This is done
+    # to estimate how microbiome profile improves the prediction compared to
+    # conventional clinical variables.
+    offset <- NULL
+    if( !is.null(col.var) ){
+        # Covariate-only Cox; LHS variables time/status are excluded from '.'
+        df_covar <- data.frame(time = time, status = status, col.var)
+        model_covar <- coxph(Surv(time, status) ~ ., data = df_covar)
+        # Compute the linear predictor (log-hazard ratio) from the
+        # covariate-only Cox model.
+        offset <- predict(model_covar, type = "lp")
+    }
+
+    # Fit penalized Cox model with cross-validation. Cross-validation identifies
+    # the optimal penalty (lambda) while controlling overfitting. This model
+    # uses microbial profile as predictor. If covariates were specified, the
+    # model is built on top of the covariate-only model, i.e., to assess how
+    # much microbes improve the prediction.
+    fit <- cv.glmnet(
+        x = mat, y = y, family = "cox", type.measure = "C",
+        alpha = alpha, nfolds = nfolds, keep = TRUE,
+        offset = offset
+    )
+
+    # Select lambda value based on user input and optional nvar constraint to
+    # allow selecting lambda that balances model sparsity and predictive
+    # performance
+    lambda_value <- .lambda_selector(fit, lambda, ...)
+    # Extract coefficients at chosen lambda
+    coefs <- coef(fit, s = lambda_value)
+    coefs <- setNames(as.vector(coefs), rownames(coefs))
+    coefs <- .filter_and_normalize_coefs(coefs, ...)
+
+    # Compute risk scores for all samples
+    risk_scores <- predict(fit, mat, s = lambda_value, newoffset = offset) |>
+        as.numeric()
+    # Compute concordance index
+    c_index <- .compute_c_index(risk_scores, y, penalized = TRUE)
+
+    # Row of the CV results (aligned with fit$lambda) for the chosen lambda
+    id_row <- max(which(fit[["lambda"]] >= lambda_value), 1L)
+
+    # Return all results as a list
+    res <- list(
+        coef = coefs,
+        risk_scores = risk_scores,
+        c_index = c_index,
+        c_index_cv_mean = fit[["cvm"]][[id_row]],
+        c_index_cv_sd = fit[["cvsd"]][[id_row]],
+        fit = fit
+    )
+    return(res)
+}
+
+# Fit a standard Cox proportional hazards model
+#
+# Performs a standard (unpenalized) Cox regression on features and optional
+# covariates. This is used when no penalization is desired.
+.fit_standard_cox <- function(mat, time, status, col.var, ...){
+    # Survival response object; handed to the C-index helper below
+    y <- Surv(time, status)
+    # Combine survival times, status, features, and optional covariates into one
+    # data frame
+    if( is.null(col.var) ){
+        df <- data.frame(time = time, status = status, mat)
+    } else {
+        df <- data.frame(time = time, status = status, mat, col.var)
+    }
+
+    # Fit standard Cox model; '.' covers all columns except time and status
+    fit <- coxph(Surv(time, status) ~ ., data = df)
+
+    # Extract raw coefficients
+    coefs <- coef(fit)
+    # Filter and normalize coefficients
+    coefs <- .filter_and_normalize_coefs(coefs, ...)
+    # Compute risk scores for all samples
+    risk_scores <- predict(fit, type = "lp")
+    # Compute apparent concordance index
+    c_index <- .compute_c_index(risk_scores, y, fit, penalized = FALSE)
+
+    # Return results as a list
+    res <- list(
+        coef = coefs,
+        risk_scores = as.numeric(risk_scores),
+        c_index = c_index,
+        fit = fit
+    )
+    return(res)
+}
+
+# Filter and normalize coefficients
+#
+# Removes coefficients with magnitude below a threshold and rescales the
+# remaining coefficients.
+.filter_and_normalize_coefs <- function(coefs, coef.threshold = 0, ...) {
+    # Keep coefficients exceeding the threshold in absolute value so that we
+    # can focus on meaningful predictors. Non-estimable (NA) coefficients are
+    # dropped as well, otherwise they would turn the rescaled vector into NA.
+    res <- coefs[ !is.na(coefs) & abs(coefs) > coef.threshold ]
+    # Rescale the remaining coefficients so that their absolute values sum to 2
+    if( length(res) > 0L ){
+        res <- 2*res / sum(abs(res))
+    }
+    return(res)
+}
+
+# Compute concordance index (C-index) for survival predictions
+#
+# Evaluates the discriminative ability of a survival model:
+# how well the predicted risk scores rank patients according to observed
+# survival. Returns NA when there are no (or only zero) risk scores.
+#' @importFrom glmnet Cindex
+.compute_c_index <- function(risk_scores, y, fit = NULL, penalized = TRUE){
+    res <- NA_real_
+    if( !(length(risk_scores) == 0 || all(risk_scores == 0)) ){
+        if( penalized ){
+            # Penalized Cox: use glmnet's Cindex function
+            res <- Cindex(pred = risk_scores, y)
+        } else {
+            # Standard Cox: extract C-index from the model summary
+            res <- summary(fit)[["concordance"]][[1L]]
+        }
+    }
+    return(res)
+}
+
+# Select lambda value from a cross-validated glmnet fit
+#
+# Allows choosing lambda based on CV results or restricting the model to a
+# maximum number of variables.
+.lambda_selector <- function(fit, lambda, nvar = NULL, ...) {
+    # Handle nvar cutoff
+    if( !is.null(nvar) ){
+        # Identify lambda values with number of non-zero coefficients <= nvar.
+        valid <- which(fit[["glmnet.fit"]][["df"]] <= nvar)
+        # If such lambdas exist, take the last one on the lambda path, i.e.
+        # the one furthest along the path that still satisfies the cutoff.
+        if( length(valid) > 0L ){
+            lambda <- fit[["glmnet.fit"]][["lambda"]][[max(valid)]]
+        }
+    }
+    # Handle lambda selection by name
+    if( is.character(lambda) ){
+        # Convert character string to actual numeric lambda value from fit
+        # Why: "lambda.min" or "lambda.1se" are stored inside fit; this
+        # retrieves the correct numeric value
+        lambda <- fit[[lambda]]
+    }
+    return(lambda)
+}
diff --git a/man/getSurvival.Rd b/man/getSurvival.Rd
new file mode 100644
index 0000000..9eac72c
--- /dev/null
+++ b/man/getSurvival.Rd
@@ -0,0 +1,104 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/AllGenerics.R, R/getSurvival.R
+\name{getSurvival}
+\alias{getSurvival}
+\alias{addSurvival}
+\alias{addSurvival,SummarizedExperiment-method}
+\alias{getSurvival,SummarizedExperiment-method}
+\title{Survival analysis}
+\usage{
+getSurvival(x, ...)
+
+addSurvival(x, ...)
+
+\S4method{addSurvival}{SummarizedExperiment}(x, time.col, event.col, name = "survival", ...)
+
+\S4method{getSurvival}{SummarizedExperiment}(x, time.col, event.col, assay.type = "counts", col.var = NULL, ...)
+}
+\arguments{
+\item{x}{A
+\code{\link[SummarizedExperiment:SummarizedExperiment-class]{SummarizedExperiment}}
+object.}
+
+\item{...}{additional arguments.
+\itemize{
+\item \code{penalized}: \code{Logical}. If \code{TRUE}, fit penalized Cox
+regression using \code{glmnet}. If \code{FALSE}, fit standard Cox model
+using \code{survival::coxph}. (Default: \code{TRUE})
+
+\item \code{lambda}: \code{Character or numeric}. Penalization parameter
+passed to \code{\link[glmnet]{cv.glmnet}}. Use \code{"lambda.1se"},
+\code{"lambda.min"}, or a numeric value. (Default: \code{"lambda.1se"})
+
+\item \code{alpha}: \code{Numeric scalar}.
Elastic net mixing parameter +that controls the balance between Lasso and Ridge regression: +\code{alpha = 1} corresponds to Lasso, +\code{alpha = 0} corresponds to Ridge. +Values between 0 and 1 specify a combination of the two. +(Default: \code{0.9}) + +\item \code{nfolds}: \code{Integer scalar}. Number of cross-validation +folds for \code{cv.glmnet}. (Default: \code{10}) + +\item \code{nvar}: \code{Integer scalar}. Optional. Maximum number of +variables (log-ratios) to include in the model. (Default: \code{NULL}) + +\item \code{coef.threshold}: \code{Numeric scalar}. Minimum absolute value +for a coefficient to be included in the final model. (Default: \code{0}) +}} + +\item{time.col}{\code{Character scalar}. Column name in \code{colData(x)} +representing time to event or follow-up time. Must be numeric.} + +\item{event.col}{\code{Character scalar}. Column name in \code{colData(x)} +representing event occurrence. Accepts numeric (\code{0}/\code{1}) or +logical (\code{TRUE}/\code{FALSE}) values.} + +\item{name}{\code{Character vector}. Specifies a column name for storing +divergence results. +(Default: \code{c("divergence", "time_diff", "ref_samples")})} + +\item{assay.type}{\code{Character scalar}. Specifies which assay values are +used in the dissimilarity estimation. (Default: \code{"counts"})} + +\item{col.var}{\code{Character vector}. Optional. Specifies covariate +columns in \code{colData(x)} to adjust for in the survival model. 
+(Default: \code{NULL})} +} +\value{ +A list with model summaries: +\itemize{ +\item \code{coef}: estimated model coefficients +\item \code{risk_scores}: predicted risk scores +\item \code{c_index}: apparent concordance index +\item \code{c_index_cv_mean}: mean cross-validated C-index (if penalized) +\item \code{c_index_cv_sd}: SD of cross-validated C-index (if penalized) +\item \code{fit}: fitted model object (coxph or cv.glmnet) +} +} +\description{ +Fit a (penalized) Cox proportional hazards model on microbiome data contained +in a SummarizedExperiment object. Data transformations (e.g. pairwise +log-ratios) should be handled upstream +(e.g. with \code{mia::transformAssay()}). +} +\examples{ +data(crohn_survival) +tse <- crohn_survival +tse <- transformAssay(tse, method = "relabundance") +fit <- getSurvival( + tse, assay.type = "relabundance", + time.col = "event_time", event.col = "event" +) + +} +\references{ +Meritxell Pujolassos , Antoni Susín , M.Luz Calle (2024). +\emph{Microbiome compositional data analysis for survival studies } +NAR Genomics and Bioinformatics, 6(2), lqae038. 
+\doi{10.1093/nargab/lqae038}
+}
+\seealso{
+\code{\link[coda4microbiome]{coda_coxnet}},
+\code{\link[glmnet]{cv.glmnet}}
+}
diff --git a/tests/testthat/test-getSurvival.R b/tests/testthat/test-getSurvival.R
new file mode 100644
index 0000000..b4bfccb
--- /dev/null
+++ b/tests/testthat/test-getSurvival.R
@@ -0,0 +1,345 @@
+# Check that invalid or missing inputs are rejected
+test_that("getSurvival validates input", {
+    # Load and prepare data
+    data(crohn_survival)
+    tse <- crohn_survival
+    tse <- transformAssay(tse, method = "relabundance")
+
+    # Test that function works with valid inputs
+    expect_no_error({
+        fit <- getSurvival(
+            tse,
+            assay.type = "relabundance",
+            time.col = "event_time",
+            event.col = "event"
+        )
+    })
+
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "relabundance"
+        )
+    })
+    expect_error({
+        getSurvival(
+            tse
+        )
+    })
+
+    # Test error with non-existent time column
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "relabundance",
+            time.col = "nonexistent_time",
+            event.col = "event"
+        )
+    })
+
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "relabundance",
+            event.col = "event"
+        )
+    })
+
+    # Test error with non-existent event column
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "relabundance",
+            time.col = "event_time",
+            event.col = "nonexistent_event"
+        )
+    })
+
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "relabundance",
+            time.col = "event_time"
+        )
+    })
+
+    # Test error with non-existent assay
+    expect_error({
+        getSurvival(
+            tse,
+            assay.type = "nonexistent_assay",
+            time.col = "event_time",
+            event.col = "event"
+        )
+    })
+})
+
+# Check that method works correctly
+test_that("getSurvival works correctly", {
+    # Load the example data
+    data(crohn_survival)
+    tse <- crohn_survival
+
+    # Check that the data loaded correctly
+    expect_s4_class(tse, "TreeSummarizedExperiment")
+
+    # Verify required columns exist
+    expect_true("event_time" %in% colnames(colData(tse)))
+    expect_true("event" %in% colnames(colData(tse)))
+
+    # Transform to relative abundance as in example
+    tse <- transformAssay(tse, method = "relabundance")
+
+    # Test basic getSurvival functionality
+    fit <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event"
+    )
+
+    # Check that result is a list with expected components
+    expect_type(fit, "list")
+    expect_true(all(c("coef", "risk_scores", "c_index", "fit") %in% names(fit)))
+
+    # Check coefficients
+    expect_type(fit$coef, "double")
+    expect_true(is.numeric(fit$coef))
+
+    # Check risk scores match number of samples
+    expect_length(fit$risk_scores, ncol(tse))
+    expect_true(is.numeric(fit$risk_scores))
+    expect_true(all(is.finite(fit$risk_scores)))
+
+    # Check C-index is valid (NOTE(review): it is NA if all risk scores are 0)
+    expect_true(is.numeric(fit$c_index))
+    expect_true(fit$c_index >= 0 && fit$c_index <= 1)
+
+    # Check that fit object exists and is correct type
+    expect_true(!is.null(fit$fit))
+    expect_s3_class(fit$fit, "cv.glmnet")
+
+    # Check cross-validation results are present (penalized = TRUE by default)
+    expect_true("c_index_cv_mean" %in% names(fit))
+    expect_true("c_index_cv_sd" %in% names(fit))
+    expect_true(is.numeric(fit$c_index_cv_mean))
+    expect_true(is.numeric(fit$c_index_cv_sd))
+})
+
+# Check that method handles different params passed in ...
+test_that("getSurvival handles different parameters", {
+    # Load and prepare data
+    data(crohn_survival)
+    tse <- crohn_survival
+    tse <- transformAssay(tse, method = "relabundance")
+
+    # Test with different alpha values
+    fit_ridge <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        alpha = 0.1 # Mostly ridge (pure ridge would be alpha = 0)
+    )
+
+    fit_lasso <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        alpha = 1.0 # Lasso regression
+    )
+
+    # Both should return valid results
+    expect_type(fit_ridge, "list")
+    expect_type(fit_lasso, "list")
+
+    # Results should be different due to different regularization
+    expect_false(identical(fit_ridge$coef, fit_lasso$coef))
+    expect_false(identical(fit_ridge$risk_scores, fit_lasso$risk_scores))
+
+    # Test with standard Cox regression (no penalization)
+    fit_standard <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        penalized = FALSE
+    )
+
+    expect_type(fit_standard, "list")
+    expect_s3_class(fit_standard$fit, "coxph")
+
+    # Standard model should not have CV results
+    expect_false("c_index_cv_mean" %in% names(fit_standard))
+    expect_false("c_index_cv_sd" %in% names(fit_standard))
+})
+
+# Check that method handles covariates
+test_that("getSurvival handles covariates", {
+    # Load and prepare data
+    data(crohn_survival)
+    tse <- crohn_survival
+    tse <- transformAssay(tse, method = "relabundance")
+
+    set.seed(42) # For reproducibility
+
+    # Determine sample size
+    n <- ncol(tse)
+
+    # Generate synthetic covariates
+    colData(tse)$age <- round(rnorm(n, mean = 40, sd = 12)) # ages, e.g. 20–70
+    colData(tse)$sex <- factor(sample(c("M", "F"), n, replace = TRUE))
+
+    # Test with the synthetic covariates generated above
+    fit_with_covs <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        col.var = c("age", "sex")
+    )
+
+    # Test without covariates for comparison
+    fit_without_covs <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event"
+    )
+
+    # Both should return valid results
+    expect_type(fit_with_covs, "list")
+    expect_type(fit_without_covs, "list")
+
+    # Results should be different when covariates are included
+    expect_false(identical(fit_with_covs$risk_scores, fit_without_covs$risk_scores))
+
+    # Both should have same structure
+    essential_components <- c("coef", "risk_scores", "c_index", "fit")
+    expect_true(all(essential_components %in% names(fit_with_covs)))
+    expect_true(all(essential_components %in% names(fit_without_covs)))
+})
+
+# Check that method handles coefficient filtering
+test_that("getSurvival handles coefficient filtering", {
+    # Load and prepare data
+    data(crohn_survival)
+    tse <- crohn_survival
+    tse <- transformAssay(tse, method = "relabundance")
+
+    # Test with different coefficient thresholds
+    fit_no_filter <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        coef.threshold = 0 # No filtering
+    )
+
+    fit_filtered <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        coef.threshold = 0.1 # Filter small coefficients
+    )
+
+    # Filtered result should have fewer or equal coefficients
+    expect_true(length(fit_filtered$coef) <= length(fit_no_filter$coef))
+
+    # Remaining coefs should exceed 0.1 (NOTE(review): checked after rescaling)
+    if(length(fit_filtered$coef) > 0) {
+        expect_true(all(abs(fit_filtered$coef) > 0.1))
+    }
+
+    # Test with high threshold that might filter out all coefficients
+    fit_high_filter <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        coef.threshold = 10 # Very high threshold
+    )
+
+    # Should still return valid structure even if no coefficients pass threshold
+    expect_type(fit_high_filter, "list")
+    expect_true("coef" %in% names(fit_high_filter))
+})
+
+# Check that method handles different transformations
+test_that("getSurvival handles different assay types", {
+    # Load data
+    data(crohn_survival)
+    tse <- crohn_survival
+
+    # Test with original counts
+    fit_counts <- getSurvival(
+        tse,
+        assay.type = "counts",
+        time.col = "event_time",
+        event.col = "event"
+    )
+
+    # Transform and test with relative abundances
+    tse <- transformAssay(tse, method = "relabundance")
+    fit_relabundance <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event"
+    )
+
+    # Both should return valid results
+    expect_type(fit_counts, "list")
+    expect_type(fit_relabundance, "list")
+
+    # Results should be different due to different data scaling
+    expect_false(identical(fit_counts$coef, fit_relabundance$coef))
+    expect_false(identical(fit_counts$risk_scores, fit_relabundance$risk_scores))
+
+    # Add CLR transformation and test
+    tse <- transformAssay(tse, method = "clr", pseudocount = TRUE)
+    fit_clr <- getSurvival(
+        tse,
+        assay.type = "clr",
+        time.col = "event_time",
+        event.col = "event"
+    )
+
+    expect_type(fit_clr, "list")
+    expect_false(identical(fit_clr$coef, fit_relabundance$coef))
+})
+
+# Check that the method's results are reproducible under a fixed seed
+test_that("getSurvival results are reproducible", {
+    # Load and prepare data
+    data(crohn_survival)
+    tse <- crohn_survival
+    tse <- transformAssay(tse, method = "relabundance")
+
+    # Set seed and run analysis
+    set.seed(123)
+    fit1 <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        nfolds = 5 # Use fewer folds for faster testing
+    )
+
+    # Set same seed and run analysis again
+    set.seed(123)
+    fit2 <- getSurvival(
+        tse,
+        assay.type = "relabundance",
+        time.col = "event_time",
+        event.col = "event",
+        nfolds = 5
+    )
+
+    # Results should be identical with same seed
+    expect_identical(fit1$coef, fit2$coef)
+    expect_identical(fit1$risk_scores, fit2$risk_scores)
+    expect_equal(fit1$c_index, fit2$c_index)
+})
\ No newline at end of file