Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed .DS_Store
Binary file not shown.
1 change: 0 additions & 1 deletion .github/.gitignore

This file was deleted.

Binary file removed R/.DS_Store
Binary file not shown.
62 changes: 48 additions & 14 deletions R/itr_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -397,29 +397,63 @@ getAupecOutput = function(
}

# transformation function for taucv matrix
gettaucv <- function(
fit,
...
){
estimates <- fit$estimates
# gettaucv <- function(
# estimates,
# ...
# ){
# # estimates <- fit$estimates
# fit_ml <- estimates$fit_ml
# n_folds <- estimates$params$n_folds
# tau_cv <- list()
#
# # for one model
# for (k in seq(n_folds)) {
# tau_cv[[k]] <- fit_ml[["causal_forest"]][[k]][["tau_cv"]]
# }
#
# # convert to a single matrix
# tau_cv <- do.call(cbind, tau_cv)
#
# return(tau_cv)
#
# }
gettaucv <- function(estimates, ...) {
# Assuming fit_ml is a named list of models
fit_ml <- estimates$fit_ml
n_folds <- estimates$params$n_folds
tau_cv <- list()

# for one model
for (k in seq(n_folds)) {
tau_cv[[k]] <- fit_ml[["causal_forest"]][[k]][["tau_cv"]]
# Prepare a list to store tau_cv for each model and each fold
all_models_tau_cv <- list()

# Loop over all models in fit_ml
for (model_name in names(fit_ml)) {
model_tau_cv <- list()

# For one model, loop over all k folds
for (k in seq(n_folds)) {
# Check if the k-th fold exists for the current model
if (is.list(fit_ml[[model_name]]) && k <= length(fit_ml[[model_name]])) {
# Check if tau_cv exists for the k-th fold of the current model
if ("tau_cv" %in% names(fit_ml[[model_name]][[k]])) {
model_tau_cv[[k]] <- fit_ml[[model_name]][[k]][["tau_cv"]]
} else {
warning(paste("tau_cv not found in fold", k, "of model", model_name))
}
} else {
warning(paste("Fold", k, "not found in model", model_name))
}
}

# Convert to a single matrix and store it with the model's name
all_models_tau_cv[[model_name]] <- do.call(cbind, model_tau_cv)
}

# convert to a single matrix
tau_cv <- do.call(cbind, tau_cv)

return(tau_cv)

return(all_models_tau_cv)
}




# rename the columns of the data frame with the interaction terms
rename_interaction_terms <- function(interaction_df){
colnames(interaction_df) <- gsub(":", "_", colnames(interaction_df))
Expand Down
159 changes: 85 additions & 74 deletions R/itr_summary.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Summarize estimate_itr output
#' Summarize estimate_itr output
#' @param object An object of \code{estimate_itr} class (typically an output of \code{estimate_itr()} function).
#' @param ... Other parameters.
#' @param ... Other parameters.
#' @importFrom stats pnorm
#' @export
summary.itr <- function(object, ...) {
Expand All @@ -16,7 +16,7 @@ summary.itr <- function(object, ...) {
# fit <- object$qoi

# -----------------------------------------
# estimate ITR from ML algorithms
# estimate ITR from ML algorithms
# -----------------------------------------

if(length(estimate_algs) != 0){
Expand Down Expand Up @@ -46,7 +46,7 @@ if(length(estimate_algs) != 0){
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(
statistic = pape / sd,
statistic = pape / sd,
p.value = 2 * pnorm(abs(pape / sd), lower.tail = FALSE)
) %>%
rename(
Expand Down Expand Up @@ -150,7 +150,7 @@ if(length(estimate_algs) != 0){
)
}

aupec_algs_vec <- fit$AUPEC %>%
aupec_algs_vec <- fit$AUPEC %>%
map(., ~ .x$aupec_cv) %>%
bind_rows() %>%
mutate(
Expand Down Expand Up @@ -208,7 +208,7 @@ if(length(estimate_user) != 0){
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(
statistic = pape / sd,
statistic = pape / sd,
p.value = 2 * pnorm(abs(pape / sd), lower.tail = FALSE)
) %>%
rename(
Expand Down Expand Up @@ -316,93 +316,104 @@ print.summary.itr <- function(x, ...) {
}


#' Summarize test_itr output
#' Summarize test_itr output
#' @param object An object of \code{test_itr} class (typically an output of \code{test_itr()} function).
#' @param ... Other parameters.
#' @importFrom stats pnorm
#' @param ... Other parameters.
#' @importFrom dplyr mutate rename select bind_rows %>%
#' @importFrom purrr map_dfr
#' @importFrom tibble as_tibble tibble
#' @export
summary.test_itr <- function(object, ...) {
out <- list()
consist_tibble <- tibble()
het_tibble <- tibble()

## -----------------------------------------
## hypothesis tests
## -----------------------------------------
if (names(object[1]) == "consist") {

# parameters for test_itr object
consist <- object$consist
het <- object$het
consist_names <- names(consist)
het_names <- names(het)

# reformat
out[["Consistency"]] <- consist %>%
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(algorithm = consist_names) %>%
rename(statistic = stat,
p.value = pval) %>%
select(algorithm, statistic, p.value)

requireNamespace("dplyr", quietly = TRUE)
requireNamespace("purrr", quietly = TRUE)
requireNamespace("tibble", quietly = TRUE)

out[["Heterogeneity"]] <- het %>%
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(algorithm = het_names) %>%
rename(statistic = stat,
p.value = pval) %>%
select(algorithm, statistic, p.value)
}


if (names(object[1]) == "consistcv") {

# parameters for test_itr object
consist <- object$consistcv
het <- object$hetcv
consist_names <- names(consist)
het_names <- names(het)

# reformat
out[["Consistency_cv"]] <- consist %>%
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(algorithm = consist_names) %>%
rename(statistic = stat,
p.value = pval) %>%
select(algorithm, statistic, p.value)
process_results <- function(data, names) {
if (length(data) == 0) {
stop("Data for processing is empty. Check the structure of the 'test_itr' object.")
}
if (length(names) != length(data)) {
stop("Mismatch between data elements and their names. Each element of data should have a corresponding name.")
}
purrr::map_dfr(data, ~ tibble::as_tibble(.x), .id = "algorithm") %>%
dplyr::mutate(algorithm = names) %>%
dplyr::rename(statistic = stat, p.value = pval) %>%
dplyr::select(algorithm, statistic, p.value)
}

out[["Heterogeneity_cv"]] <- het %>%
map(., ~ as_tibble(.)) %>%
bind_rows() %>%
mutate(algorithm = het_names) %>%
rename(statistic = stat,
p.value = pval) %>%
select(algorithm, statistic, p.value)
out <- list()

if ("consist" %in% names(object)) {
consist <- object$consist
het <- object$het
out[["Consistency"]] <- process_results(consist, names(consist))
out[["Heterogeneity"]] <- process_results(het, names(het))
} else if ("consistcv" %in% names(object)) {
consistcv <- object$consistcv
hetcv <- object$hetcv
if (is.null(consistcv) || is.null(hetcv)) {
stop("The 'consistcv' or 'hetcv' elements are NULL.")
}
out[["Consistency_cv"]] <- process_results(consistcv, names(consistcv))
out[["Heterogeneity_cv"]] <- process_results(hetcv, names(hetcv))
} else {
stop("Invalid 'test_itr' object: neither 'consist' nor 'consistcv' found in names.")
}

class(out) <- c("summary.test_itr", class(out))

return(out)
}



#' Print
#' @importFrom cli cat_rule
#' @param x An object of \code{summary.test_itr} class. This is typically an output of \code{summary.test_itr()} function.
#' @param ... Other parameters.
#' @param ... Other parameters.
#' @export
print.summary.test_itr <- function(x, ...) {
# Ensure the cli package is available
if (!requireNamespace("cli", quietly = TRUE)) {
stop("The 'cli' package is required for printing this summary. Please install it using install.packages('cli').")
}

# Consistency
cli::cat_rule(left = "The Consistency Test Results for GATEs")
if ("Consistency" %in% names(x)) {
print(as.data.frame(x[["Consistency"]]), digits = 2)
} else {
cli::cat_line("No consistency results available (sample-splitting).")
}
cli::cat_line("")

# Heterogeneity
cli::cat_rule(left = "The Heterogeneity Test Results for GATEs")
if ("Heterogeneity" %in% names(x)) {
print(as.data.frame(x[["Heterogeneity"]]), digits = 2)
} else {
cli::cat_line("No heterogeneity results available (sample-splitting).")
}
cli::cat_line("")

# Rank Consistency Test
cli::cat_rule(left = "Rank Consistency Test Results")
print(as.data.frame(x[["Consistency"]], digits = 2))
# Consistency Cross-Validation
cli::cat_rule(left = "The Consistency Test Results for GATEs (Cross-validation)")
if ("Consistency_cv" %in% names(x)) {
print(as.data.frame(x[["Consistency_cv"]]), digits = 2)
} else {
cli::cat_line("No consistency results available (cross-validation).")
}
cli::cat_line("")

# Group Heterogeneity Test
cli::cat_rule(left = "Group Heterogeneity Test Results")
print(as.data.frame(x[["Heterogeneity"]], digits = 2))
# Heterogeneity Cross-Validation
cli::cat_rule(left = "The Heterogeneity Test Results for GATEs (Cross-validation)")
if ("Heterogeneity_cv" %in% names(x)) {
print(as.data.frame(x[["Heterogeneity_cv"]]), digits = 2)
} else {
cli::cat_line("No heterogeneity results available (cross-validation).")
}
cli::cat_line("")

invisible(x)
}


17 changes: 9 additions & 8 deletions R/main.r
Original file line number Diff line number Diff line change
Expand Up @@ -812,28 +812,29 @@ evaluate_itr <- function(
}

#' Conduct hypothesis tests
#' @param fit Fitted model. Usually an output from \code{estimate_itr}
#' @param fit Fitted model. Usually an output from \code{evaluate_itr}
#' @param nsim Number of Monte Carlo simulations used to simulate the null distributions. Default is 1000.
#' @param ... Further arguments passed to the function.
#' @return An object of \code{test_itr} class
#' @export
test_itr <- function(
fit,
est,
nsim = 1000,
...
) {

# test parameters
estimates <- fit$estimates
out_algs <- est$out_algs
estimates <- out_algs$estimates
cv <- estimates$params$cv
fit_ml <- estimates$fit_ml
Tcv <- estimates$Tcv
Ycv <- estimates$Ycv
indcv <- estimates$indcv
n_folds <- estimates$params$n_folds
ngates <- estimates$params$ngates
algorithms <- fit$df$algorithms
outcome <- fit$df$outcome
algorithms <- out_algs$df$algorithms
outcome <- out_algs$df$outcome

# caret and rlearner parameters
caret_algorithms <- estimates$params$caret_algorithms
Expand All @@ -849,7 +850,7 @@ test_itr <- function(
## =================================

if(cv == FALSE){
cat('Conduct hypothesis tests for GATEs unde sample splitting ...\n')
cat('Conduct hypothesis tests for GATEs unde sample-splitting ...\n')

# create empty lists to for consistcv and hetcv
consist <- list()
Expand Down Expand Up @@ -893,14 +894,14 @@ test_itr <- function(

consistcv[[i]] <- consistcv.test(
T = Tcv,
tau = gettaucv(fit)[[i]],
tau = gettaucv(estimates)[[i]],
Y = Ycv,
ind = indcv,
ngates = ngates)

hetcv[[i]] <- hetcv.test(
T = Tcv,
tau = gettaucv(fit)[[i]],
tau = gettaucv(estimates)[[i]],
Y = Ycv,
ind = indcv,
ngates = ngates)
Expand Down
Loading