diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index e9d789b..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/.gitignore b/.github/.gitignore deleted file mode 100644 index 2d19fc7..0000000 --- a/.github/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.html diff --git a/R/.DS_Store b/R/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/R/.DS_Store and /dev/null differ diff --git a/R/itr_helpers.R b/R/itr_helpers.R index 598163c..262d5df 100644 --- a/R/itr_helpers.R +++ b/R/itr_helpers.R @@ -397,29 +397,63 @@ getAupecOutput = function( } # transformation function for taucv matrix -gettaucv <- function( - fit, - ... -){ - estimates <- fit$estimates +# gettaucv <- function( +# estimates, +# ... +# ){ +# # estimates <- fit$estimates +# fit_ml <- estimates$fit_ml +# n_folds <- estimates$params$n_folds +# tau_cv <- list() +# +# # for one model +# for (k in seq(n_folds)) { +# tau_cv[[k]] <- fit_ml[["causal_forest"]][[k]][["tau_cv"]] +# } +# +# # convert to a single matrix +# tau_cv <- do.call(cbind, tau_cv) +# +# return(tau_cv) +# +# } +gettaucv <- function(estimates, ...) { + # Assuming fit_ml is a named list of models fit_ml <- estimates$fit_ml n_folds <- estimates$params$n_folds - tau_cv <- list() - # for one model - for (k in seq(n_folds)) { - tau_cv[[k]] <- fit_ml[["causal_forest"]][[k]][["tau_cv"]] + # Prepare a list to store tau_cv for each model and each fold + all_models_tau_cv <- list() + + # Loop over all models in fit_ml + for (model_name in names(fit_ml)) { + model_tau_cv <- list() + + # For one model, loop over all k folds + for (k in seq(n_folds)) { + # Check if the k-th fold exists for the current model + if (is.list(fit_ml[[model_name]]) && k <= length(fit_ml[[model_name]])) { + # Check if tau_cv exists for the k-th fold of the current model + if ("tau_cv" %in% names(fit_ml[[model_name]][[k]])) { + model_tau_cv[[k]] <- fit_ml[[model_name]][[k]][["tau_cv"]] + } else { + warning(paste("tau_cv not found in fold", k, "of model", model_name)) + } + } else { + warning(paste("Fold", k, "not found in model", model_name)) + } + } + + # Convert to a single matrix and store it with the model's name + all_models_tau_cv[[model_name]] <- do.call(cbind, model_tau_cv) } - # convert to a single matrix - tau_cv <- do.call(cbind, tau_cv) - - return(tau_cv) - + return(all_models_tau_cv) } + # rename the columns of the data frame with the interaction terms rename_interaction_terms <- function(interaction_df){ colnames(interaction_df) <- gsub(":", "_", colnames(interaction_df)) diff --git a/R/itr_summary.R b/R/itr_summary.R index d8fd0d1..a0b1695 100644 --- a/R/itr_summary.R +++ b/R/itr_summary.R @@ -1,6 +1,6 @@ -#' Summarize estimate_itr output +#' Summarize estimate_itr output #' @param object An object of \code{estimate_itr} class (typically an output of \code{estimate_itr()} function). -#' @param ... Other parameters. +#' @param ... Other parameters. #' @importFrom stats pnorm #' @export summary.itr <- function(object, ...) { @@ -16,7 +16,7 @@ summary.itr <- function(object, ...) 
{ # fit <- object$qoi # ----------------------------------------- -# estimate ITR from ML algorithms +# estimate ITR from ML algorithms # ----------------------------------------- if(length(estimate_algs) != 0){ @@ -46,7 +46,7 @@ if(length(estimate_algs) != 0){ map(., ~ as_tibble(.)) %>% bind_rows() %>% mutate( - statistic = pape / sd, + statistic = pape / sd, p.value = 2 * pnorm(abs(pape / sd), lower.tail = FALSE) ) %>% rename( @@ -150,7 +150,7 @@ if(length(estimate_algs) != 0){ ) } - aupec_algs_vec <- fit$AUPEC %>% + aupec_algs_vec <- fit$AUPEC %>% map(., ~ .x$aupec_cv) %>% bind_rows() %>% mutate( @@ -208,7 +208,7 @@ if(length(estimate_user) != 0){ map(., ~ as_tibble(.)) %>% bind_rows() %>% mutate( - statistic = pape / sd, + statistic = pape / sd, p.value = 2 * pnorm(abs(pape / sd), lower.tail = FALSE) ) %>% rename( @@ -316,93 +316,104 @@ print.summary.itr <- function(x, ...) { } -#' Summarize test_itr output +#' Summarize test_itr output #' @param object An object of \code{test_itr} class (typically an output of \code{test_itr()} function). -#' @param ... Other parameters. -#' @importFrom stats pnorm +#' @param ... Other parameters. +#' @importFrom dplyr mutate rename select bind_rows %>% +#' @importFrom purrr map_dfr +#' @importFrom tibble as_tibble tibble #' @export summary.test_itr <- function(object, ...) { - out <- list() - consist_tibble <- tibble() - het_tibble <- tibble() - - ## ----------------------------------------- - ## hypothesis tests - ## ----------------------------------------- - if (names(object[1]) == "consist") { - - # parameters for test_itr object - consist <- object$consist - het <- object$het - consist_names <- names(consist) - het_names <- names(het) - - # reformat - out[["Consistency"]] <- consist %>% - map(., ~ as_tibble(.)) %>% - bind_rows() %>% - mutate(algorithm = consist_names) %>% - rename(statistic = stat, - p.value = pval) %>% - select(algorithm, statistic, p.value) - + requireNamespace("dplyr", quietly = TRUE) + requireNamespace("purrr", quietly = TRUE) + requireNamespace("tibble", quietly = TRUE) - out[["Heterogeneity"]] <- het %>% - map(., ~ as_tibble(.)) %>% - bind_rows() %>% - mutate(algorithm = het_names) %>% - rename(statistic = stat, - p.value = pval) %>% - select(algorithm, statistic, p.value) - } - - - if (names(object[1]) == "consistcv") { - - # parameters for test_itr object - consist <- object$consistcv - het <- object$hetcv - consist_names <- names(consist) - het_names <- names(het) - - # reformat - out[["Consistency_cv"]] <- consist %>% - map(., ~ as_tibble(.)) %>% - bind_rows() %>% - mutate(algorithm = consist_names) %>% - rename(statistic = stat, - p.value = pval) %>% - select(algorithm, statistic, p.value) + process_results <- function(data, names) { + if (length(data) == 0) { + stop("Data for processing is empty. Check the structure of the 'test_itr' object.") + } + if (length(names) != length(data)) { + stop("Mismatch between data elements and their names. 
Each element of data should have a corresponding name.") + } + purrr::map_dfr(data, ~ tibble::as_tibble(.x), .id = "algorithm") %>% + dplyr::mutate(algorithm = names) %>% + dplyr::rename(statistic = stat, p.value = pval) %>% + dplyr::select(algorithm, statistic, p.value) + } - out[["Heterogeneity_cv"]] <- het %>% - map(., ~ as_tibble(.)) %>% - bind_rows() %>% - mutate(algorithm = het_names) %>% - rename(statistic = stat, - p.value = pval) %>% - select(algorithm, statistic, p.value) + out <- list() + + if ("consist" %in% names(object)) { + consist <- object$consist + het <- object$het + out[["Consistency"]] <- process_results(consist, names(consist)) + out[["Heterogeneity"]] <- process_results(het, names(het)) + } else if ("consistcv" %in% names(object)) { + consistcv <- object$consistcv + hetcv <- object$hetcv + if (is.null(consistcv) || is.null(hetcv)) { + stop("The 'consistcv' or 'hetcv' elements are NULL.") + } + out[["Consistency_cv"]] <- process_results(consistcv, names(consistcv)) + out[["Heterogeneity_cv"]] <- process_results(hetcv, names(hetcv)) + } else { + stop("Invalid 'test_itr' object: neither 'consist' nor 'consistcv' found in names.") } class(out) <- c("summary.test_itr", class(out)) - return(out) } + + #' Print #' @importFrom cli cat_rule #' @param x An object of \code{summary.test_itr} class. This is typically an output of the \code{summary.test_itr()} function. -#' @param ... Other parameters. +#' @param ... Other parameters. #' @export print.summary.test_itr <- function(x, ...) { + # Ensure the cli package is available + if (!requireNamespace("cli", quietly = TRUE)) { + stop("The 'cli' package is required for printing this summary. Please install it using install.packages('cli').") + } + + # Consistency + cli::cat_rule(left = "The Consistency Test Results for GATEs") + if ("Consistency" %in% names(x)) { + print(as.data.frame(x[["Consistency"]]), digits = 2) + } else { + cli::cat_line("No consistency results available (sample-splitting).") + } + cli::cat_line("") + + # Heterogeneity + cli::cat_rule(left = "The Heterogeneity Test Results for GATEs") + if ("Heterogeneity" %in% names(x)) { + print(as.data.frame(x[["Heterogeneity"]]), digits = 2) + } else { + cli::cat_line("No heterogeneity results available (sample-splitting).") + } + cli::cat_line("") - # Rank Consistency Test - cli::cat_rule(left = "Rank Consistency Test Results") - print(as.data.frame(x[["Consistency"]], digits = 2)) + # Consistency Cross-Validation + cli::cat_rule(left = "The Consistency Test Results for GATEs (Cross-validation)") + if ("Consistency_cv" %in% names(x)) { + print(as.data.frame(x[["Consistency_cv"]]), digits = 2) + } else { + cli::cat_line("No consistency results available (cross-validation).") + } cli::cat_line("") - # Group Heterogeneity Test - cli::cat_rule(left = "Group Heterogeneity Test Results") - print(as.data.frame(x[["Heterogeneity"]], digits = 2)) + # Heterogeneity Cross-Validation + cli::cat_rule(left = "The Heterogeneity Test Results for GATEs (Cross-validation)") + if ("Heterogeneity_cv" %in% names(x)) { + print(as.data.frame(x[["Heterogeneity_cv"]]), digits = 2) + } else { + cli::cat_line("No heterogeneity results available (cross-validation).") + } cli::cat_line("") + + invisible(x) } + diff --git a/R/main.r b/R/main.r index c40af4c..196383d 100644 --- a/R/main.r +++ b/R/main.r @@ -812,19 +812,20 @@ evaluate_itr <- function( } #' Conduct hypothesis tests -#' @param fit Fitted model. Usually an output from \code{estimate_itr} +#' @param est Fitted model.
Usually an output from \code{evaluate_itr} #' @param nsim Number of Monte Carlo simulations used to simulate the null distributions. Default is 1000. #' @param ... Further arguments passed to the function. #' @return An object of \code{test_itr} class #' @export test_itr <- function( - fit, + est, nsim = 1000, ... ) { # test parameters - estimates <- fit$estimates + out_algs <- est$out_algs + estimates <- out_algs$estimates cv <- estimates$params$cv fit_ml <- estimates$fit_ml Tcv <- estimates$Tcv @@ -832,8 +833,8 @@ test_itr <- function( indcv <- estimates$indcv n_folds <- estimates$params$n_folds ngates <- estimates$params$ngates - algorithms <- fit$df$algorithms - outcome <- fit$df$outcome + algorithms <- out_algs$df$algorithms + outcome <- out_algs$df$outcome # caret and rlearner parameters caret_algorithms <- estimates$params$caret_algorithms @@ -849,7 +850,7 @@ test_itr <- function( ## ================================= if(cv == FALSE){ - cat('Conduct hypothesis tests for GATEs unde sample splitting ...\n') + cat('Conduct hypothesis tests for GATEs under sample-splitting ...\n') # create empty lists for consist and het consist <- list() @@ -893,14 +894,14 @@ consistcv[[i]] <- consistcv.test( T = Tcv, - tau = gettaucv(fit)[[i]], + tau = gettaucv(estimates)[[i]], Y = Ycv, ind = indcv, ngates = ngates) hetcv[[i]] <- hetcv.test( T = Tcv, - tau = gettaucv(fit)[[i]], + tau = gettaucv(estimates)[[i]], Y = Ycv, ind = indcv, ngates = ngates) diff --git a/docs/articles/test_itr.html b/docs/articles/test_itr.html new file mode 100644 index 0000000..16d65a5 --- /dev/null +++ b/docs/articles/test_itr.html @@ -0,0 +1,249 @@ +Nonparametric statistical tests for treatment heterogeneity and rank consistency across multiple ML algorithms • evalITR
In practice, machine learning (ML) algorithms may fail to ascertain heterogeneous treatment effects due to small sample sizes, high dimensionality, and arbitrary parameter tuning. The test_itr function allows users to empirically validate the GATE estimates produced by various ML algorithms with statistical testing. In particular, there are two types of nonparametric statistical tests: (1) a test of treatment effect heterogeneity across groups and (2) a test of rank consistency of the GATEs. The tests are based on the idea that, if an ML algorithm produces a reasonable scoring rule (obtained via the estimate_itr function), it is reasonable to expect that (1) the GATEs differ across groups and (2) the rank ordering of the GATEs by magnitude is monotonic. The following describes each test.
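To build intuition, the permutation logic behind a test of this kind can be sketched in a few lines of base R. This is an illustration under simplified assumptions (difference-in-means GATEs, permuted group labels), not the exact procedure that test_itr implements; the simulated treatment effect is homogeneous, so the sketch should fail to reject.
+set.seed(1)
+n <- 400
+treat <- rbinom(n, 1, 0.5)                    # randomized treatment
+y <- 1 + 0.5 * treat + rnorm(n)               # homogeneous effect: the null holds
+score <- rnorm(n)                             # a hypothetical scoring rule
+group <- cut(rank(score), 5, labels = FALSE)  # five groups, as with ngates = 5
+
+# difference-in-means GATE within each group
+gate <- sapply(1:5, function(g)
+  mean(y[treat == 1 & group == g]) - mean(y[treat == 0 & group == g]))
+stat <- max(gate) - min(gate)                 # heterogeneity statistic
+
+# simulate its null distribution by permuting group labels
+null_stats <- replicate(1000, {
+  gp <- sample(group)
+  gs <- sapply(1:5, function(g)
+    mean(y[treat == 1 & gp == g]) - mean(y[treat == 0 & gp == g]))
+  max(gs) - min(gs)
+})
+mean(null_stats >= stat)                      # permutation p-value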

+

Following the previous examples, we first estimate GATEs using causal forest (causal_forest), Bayesian Additive Regression Trees (bartc), LASSO (lasso), and random forest (rf) under cross-validation with the estimate_itr function. We specify the number of groups into which the sample is divided through the ngates argument. By setting ngates = 5 in the example below, we estimate the heterogeneous impact of small class sizes on students’ writing scores across 5 groups of students.

+
+# library(evalITR)
+devtools::load_all(".")
+
+# specify the trainControl method
+fitControl <- caret::trainControl(
+                           method = "repeatedcv",
+                           number = 2,
+                           repeats = 2)
+# estimate ITR
+set.seed(2023)
+fit_cv <- estimate_itr(
+               treatment = "treatment",
+               form = user_formula,
+               data = star_data,
+               trControl = fitControl,
+               algorithms = c(
+                  "causal_forest", # from caret
+                  "bartc", # from caret
+                  "lasso", # from caret 
+                  "rf"), # from caret 
+               budget = 0.2, # 20% budget constraint
+               n_folds = 5, # 5-fold cross-validation
+               ngates = 5) # 5 groups
+#> Evaluate ITR with cross-validation ...
+#> fitting treatment model via method 'bart'
+#> fitting response model via method 'bart'
+#> fitting treatment model via method 'bart'
+#> fitting response model via method 'bart'
+#> fitting treatment model via method 'bart'
+#> fitting response model via method 'bart'
+#> fitting treatment model via method 'bart'
+#> fitting response model via method 'bart'
+#> fitting treatment model via method 'bart'
+#> fitting response model via method 'bart'
+
+# evaluate ITR
+est_cv <- evaluate_itr(fit_cv)
+
+# extract GATEs estimates
+summary(est_cv)$GATE
+#> # A tibble: 20 × 8
+#>    estimate std.deviation algorithm     group statistic p.value  upper  lower
+#>       <dbl>         <dbl> <chr>         <int>     <dbl>   <dbl>  <dbl>  <dbl>
+#>  1   -52.3          125.  causal_forest     1   -0.417   0.677  193.   -298. 
+#>  2   -61.8          109.  causal_forest     2   -0.568   0.570  151.   -275. 
+#>  3    22.8           82.6 causal_forest     3    0.275   0.783  185.   -139. 
+#>  4   135.           120.  causal_forest     4    1.13    0.259  369.    -99.3
+#>  5   -24.9          124.  causal_forest     5   -0.202   0.840  217.   -267. 
+#>  6   -24.1           88.2 bartc             1   -0.273   0.785  149.   -197. 
+#>  7   -25.5           59.5 bartc             2   -0.429   0.668   91.1  -142. 
+#>  8   -46.9          101.  bartc             3   -0.465   0.642  151.   -245. 
+#>  9   131.           102.  bartc             4    1.28    0.199  330.    -68.7
+#> 10   -15.1          116.  bartc             5   -0.130   0.896  212.   -242. 
+#> 11   -56.5           92.0 lasso             1   -0.614   0.539  124.   -237. 
+#> 12    98.6           59.5 lasso             2    1.66    0.0975 215.    -18.0
+#> 13   -25.6          114.  lasso             3   -0.224   0.823  199.   -250. 
+#> 14    40.4          117.  lasso             4    0.344   0.731  271.   -190. 
+#> 15   -38.0           96.8 lasso             5   -0.393   0.695  152.   -228. 
+#> 16  -117.            58.7 rf                1   -1.99    0.0469  -1.61 -232. 
+#> 17     2.78         124.  rf                2    0.0224  0.982  246.   -240. 
+#> 18    22.7          111.  rf                3    0.204   0.838  241.   -196. 
+#> 19   145.            94.6 rf                4    1.53    0.125  331.    -40.3
+#> 20   -35.0          118.  rf                5   -0.296   0.767  197.   -267.
+

The table reports the quintile GATE (K = 5) estimates for each ML algorithm. We find that the Random Forest produces a statistically significant negative GATE for the lowest quintile group (group 1) under cross-validation. This provides evidence that the Random Forest is able to identify a 20% subgroup whose writing scores are negatively impacted by small class sizes.
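Because summary(est_cv)$GATE returns a tibble, the estimates driving this conclusion can be isolated directly. A quick check (assuming dplyr is loaded, as in the setup) keeps only the rows significant at the 5% level; here, that is just the random forest estimate for group 1:
+# keep only the statistically significant GATE estimates
+summary(est_cv)$GATE %>%
+  dplyr::filter(p.value < 0.05)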

+

We now conduct the statistical tests of treatment effect heterogeneity and rank consistency to validate these GATE estimates. We pass the output est_cv from the evaluate_itr function to the test_itr function, which conducts the two tests simultaneously and returns the object test_est_cv. We then summarize the test statistics and p-values using the summary function. The nsim argument specifies the number of Monte Carlo simulations used for each test; the default is 1000.

+
+# conduct nonparametric tests
+test_est_cv <- test_itr(est_cv,
+                        nsim = 5000)
+#> Conduct hypothesis tests for GATEs under cross-validation ...
+
+# summarize test statistics and p-values
+summary(test_est_cv)
+#> ── The Consistency Test Results for GATEs ──────────────────────────────────────
+#> No consistency results available (sample-splitting).
+#> 
+#> ── The Heterogeneity Test Results for GATEs ────────────────────────────────────
+#> No heterogeneity results available (sample-splitting).
+#> 
+#> ── The Consistency Test Results for GATEs (Cross-validation) ───────────────────
+#>       algorithm statistic p.value
+#> 1 causal_forest      0.83    0.74
+#> 2         bartc      1.03    0.63
+#> 3         lasso      0.24    0.82
+#> 4            rf      1.32    0.66
+#> 
+#> ── The Heterogeneity Test Results for GATEs (Cross-validation) ─────────────────
+#>       algorithm statistic p.value
+#> 1 causal_forest       1.9    0.86
+#> 2         bartc       2.3    0.81
+#> 3         lasso       3.3    0.65
+#> 4            rf       6.6    0.25
+

The table reports the test statistics and p-values for each test under each algorithm. None of the ML algorithms rejects the null hypothesis of treatment effect homogeneity under cross-validation, indicating that, taken jointly, the estimated GATEs do not vary across subgroups more than chance would allow. In addition, none of the ML algorithms rejects the rank consistency hypothesis under cross-validation. Thus, there is no strong statistical evidence that these algorithms produce unreliable GATEs.
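The same results can be used programmatically: summary.test_itr (as defined in this change) returns a named list of tables, so the p-values can be pulled out by name, for example to flag algorithms that reject homogeneity:
+tests <- summary(test_est_cv)
+het <- tests[["Heterogeneity_cv"]]
+het$algorithm[het$p.value < 0.05]  # character(0): no algorithm rejects at the 5% level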

+ + + + + + + diff --git a/man/.DS_Store b/man/.DS_Store deleted file mode 100644 index 467497b..0000000 Binary files a/man/.DS_Store and /dev/null differ diff --git a/man/figures/README-caret_model-1 2.png b/man/figures/README-caret_model-1 2.png new file mode 100644 index 0000000..e2a173a Binary files /dev/null and b/man/figures/README-caret_model-1 2.png differ diff --git a/man/figures/README-caret_model-2 2.png b/man/figures/README-caret_model-2 2.png new file mode 100644 index 0000000..bb0e12a Binary files /dev/null and b/man/figures/README-caret_model-2 2.png differ diff --git a/man/figures/README-compare_itr_aupec-1 2.png b/man/figures/README-compare_itr_aupec-1 2.png new file mode 100644 index 0000000..db3b005 Binary files /dev/null and b/man/figures/README-compare_itr_aupec-1 2.png differ diff --git a/man/figures/README-compare_itr_gate-1 2.png b/man/figures/README-compare_itr_gate-1 2.png new file mode 100644 index 0000000..7d0a0a7 Binary files /dev/null and b/man/figures/README-compare_itr_gate-1 2.png differ diff --git a/man/figures/README-compare_itr_model_summary-1 2.png b/man/figures/README-compare_itr_model_summary-1 2.png new file mode 100644 index 0000000..2ef4008 Binary files /dev/null and b/man/figures/README-compare_itr_model_summary-1 2.png differ diff --git a/man/figures/README-cv_estimate-1 2.png b/man/figures/README-cv_estimate-1 2.png new file mode 100644 index 0000000..a1e0d32 Binary files /dev/null and b/man/figures/README-cv_estimate-1 2.png differ diff --git a/man/figures/README-est_extract-1 2.png b/man/figures/README-est_extract-1 2.png new file mode 100644 index 0000000..a418cb1 Binary files /dev/null and b/man/figures/README-est_extract-1 2.png differ diff --git a/man/figures/README-sl_plot-1 2.png b/man/figures/README-sl_plot-1 2.png new file mode 100644 index 0000000..158837e Binary files /dev/null and b/man/figures/README-sl_plot-1 2.png differ diff --git a/man/figures/README-user_itr_aupec-1 2.png b/man/figures/README-user_itr_aupec-1 2.png new file mode 100644 index 0000000..63304d2 Binary files /dev/null and b/man/figures/README-user_itr_aupec-1 2.png differ diff --git a/man/figures/README-user_itr_gate-1 2.png b/man/figures/README-user_itr_gate-1 2.png new file mode 100644 index 0000000..094375a Binary files /dev/null and b/man/figures/README-user_itr_gate-1 2.png differ diff --git a/man/figures/gate 2.png b/man/figures/gate 2.png new file mode 100644 index 0000000..3fc77af Binary files /dev/null and b/man/figures/gate 2.png differ diff --git a/man/figures/plot_5folds 2.png b/man/figures/plot_5folds 2.png new file mode 100644 index 0000000..bd59a4f Binary files /dev/null and b/man/figures/plot_5folds 2.png differ diff --git a/man/figures/rf 2.png b/man/figures/rf 2.png new file mode 100644 index 0000000..4eb1ba2 Binary files /dev/null and b/man/figures/rf 2.png differ diff --git a/tests/testthat/star 2.rda b/tests/testthat/star 2.rda new file mode 100644 index 0000000..5baf7ba Binary files /dev/null and b/tests/testthat/star 2.rda differ diff --git a/tests/testthat/test-high_level 2.R b/tests/testthat/test-high_level 2.R new file mode 100644 index 0000000..b58529b --- /dev/null +++ b/tests/testthat/test-high_level 2.R @@ -0,0 +1,42 @@ +library(evalITR) +library(dplyr) +test_that("Sample Splitting Works", { + load("star.rda") + # specifying the outcome + outcomes <- "g3tlangss" + + # specifying the treatment + treatment <- "treatment" + + # specifying the data (remove other outcomes) + star_data <- star %>% dplyr::select(-c(g3treadss,g3tmathss)) + + 
# specifying the formula + user_formula <- as.formula( + "g3tlangss ~ treatment + gender + race + birthmonth + + birthyear + SCHLURBN + GRDRANGE + GKENRMNT + GKFRLNCH + + GKBUSED + GKWHITE ") + + + # estimate ITR + fit <- estimate_itr( + treatment = treatment, + form = user_formula, + data = star_data, + algorithms = c("lasso"), + budget = 0.2, + split_ratio = 0.7) + expect_no_error(estimate_itr( + treatment = treatment, + form = user_formula, + data = star_data, + algorithms = c("lasso"), + budget = 0.2, + split_ratio = 0.7)) + + + # evaluate ITR + est <- evaluate_itr(fit) + expect_no_error(evaluate_itr(fit)) +}) + diff --git a/tests/testthat/test-low_level 2.R b/tests/testthat/test-low_level 2.R new file mode 100644 index 0000000..dd0993b --- /dev/null +++ b/tests/testthat/test-low_level 2.R @@ -0,0 +1,59 @@ +library(evalITR) + +test_that("Non Cross-Validated Functions Work", { + T = c(1,0,1,0,1,0,1,0) + That = c(0,1,1,0,0,1,1,0) + That2 = c(1,0,0,1,1,0,0,1) + tau = c(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7) + Y = c(4,5,0,2,4,1,-4,3) + papelist <- PAPE(T,That,Y) + pavlist <- PAV(T,That,Y) + papdlist <- PAPD(T,That,That2,Y,0.5) + aupeclist <- AUPEC(T,tau,Y) + gatelist <- GATE(T,tau,Y,ngates=2) + expect_type(papelist,"list") + expect_type(pavlist,"list") + expect_type(papdlist,"list") + expect_type(aupeclist,"list") + expect_type(gatelist,"list") + expect_type(papelist$pape,"double") + expect_type(pavlist$pav,"double") + expect_type(papdlist$papd,"double") + expect_type(aupeclist$aupec,"double") + expect_type(gatelist$gate,"double") + expect_type(papelist$sd,"double") + expect_type(pavlist$sd,"double") + expect_type(papdlist$sd,"double") + expect_type(aupeclist$sd,"double") + expect_type(gatelist$sd,"double") +}) + +test_that("Cross-Validated Functions Work", { + T = c(1,0,1,0,1,0,1,0) + That = matrix(c(0,1,1,0,0,1,1,0,1,0,0,1,1,0,0,1), nrow = 8, ncol = 2) + That2 = matrix(c(0,0,1,1,0,0,1,1,1,1,0,0,1,1,0,0), nrow = 8, ncol = 2) + tau = matrix(c(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,-0.5,-0.3,-0.1,0.1,0.3,0.5,0.7,0.9),nrow = 8, ncol = 2) + Y = c(4,5,0,2,4,1,-4,3) + ind = c(rep(1,4),rep(2,4)) + papelist <- PAPEcv(T,That,Y,ind,budget = 0.5) + pavlist <- PAVcv(T,That,Y,ind) + papdlist <- PAPDcv(T,That,That2,Y,ind,budget = 0.5) + aupeclist <- AUPECcv(T,tau,Y,ind) + gatelist <- GATEcv(T,tau,Y,ind,ngates=2) + expect_type(papelist,"list") + expect_type(pavlist,"list") + expect_type(papdlist,"list") + expect_type(aupeclist,"list") + expect_type(gatelist,"list") + expect_type(papelist$pape,"double") + expect_type(pavlist$pav,"double") + expect_type(papdlist$papd,"double") + expect_type(aupeclist$aupec,"double") + expect_type(gatelist$gate,"double") + expect_type(papelist$sd,"double") + expect_type(pavlist$sd,"double") + expect_type(papdlist$sd,"double") + expect_type(aupeclist$sd,"double") + expect_type(gatelist$sd,"double") +}) + diff --git a/vignettes/test_itr.Rmd b/vignettes/test_itr.Rmd new file mode 100644 index 0000000..c8c045c --- /dev/null +++ b/vignettes/test_itr.Rmd @@ -0,0 +1,84 @@ +--- +title: "Nonparametric statistical tests for treatment heterogeneity and rank consistency across multiple ML algorithms" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Nonparametric statistical tests with multiple ML algorithms} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>", + fig.path = "../man/figures/README-" + ) + +library(dplyr) + +load("../data/star.rda") + +# specifying the outcome 
+outcomes <- "g3tlangss" + +# specifying the treatment +treatment <- "treatment" + +# specifying the data (remove other outcomes) +star_data <- star %>% dplyr::select(-c(g3treadss,g3tmathss)) + +# specifying the formula +user_formula <- as.formula( + "g3tlangss ~ treatment + gender + race + birthmonth + + birthyear + SCHLURBN + GRDRANGE + GKENRMNT + GKFRLNCH + + GKBUSED + GKWHITE ") +``` + +In practice, machine learning (ML) algorithms may fail to ascertain heterogeneous treatment effects due to small sample sizes, high dimensionality, and arbitrary parameter tuning. The `test_itr` function allows users to empirically validate the GATE estimates produced by various ML algorithms with statistical testing. In particular, there are two types of nonparametric statistical tests: (1) a test of treatment effect heterogeneity across groups and (2) a test of rank consistency of the GATEs. The tests are based on the idea that, if an ML algorithm produces a reasonable scoring rule (obtained via the `estimate_itr` function), it is reasonable to expect that (1) the GATEs differ across groups and (2) the rank ordering of the GATEs by magnitude is monotonic. The following describes each test. + +Following the previous examples, we first estimate GATEs using causal forest (`causal_forest`), Bayesian Additive Regression Trees (`bartc`), LASSO (`lasso`), and random forest (`rf`) under cross-validation with the `estimate_itr` function. We specify the number of groups into which the sample is divided through the `ngates` argument. By setting `ngates = 5` in the example below, we estimate the heterogeneous impact of small class sizes on students’ writing scores across 5 groups of students. + +```{r multiple, warning=FALSE, message=FALSE} +# library(evalITR) +devtools::load_all(".") + +# specify the trainControl method +fitControl <- caret::trainControl( + method = "repeatedcv", + number = 2, + repeats = 2) +# estimate ITR +set.seed(2023) +fit_cv <- estimate_itr( + treatment = "treatment", + form = user_formula, + data = star_data, + trControl = fitControl, + algorithms = c( + "causal_forest", # from caret + "bartc", # from caret + "lasso", # from caret + "rf"), # from caret + budget = 0.2, # 20% budget constraint + n_folds = 5, # 5-fold cross-validation + ngates = 5) # 5 groups + +# evaluate ITR +est_cv <- evaluate_itr(fit_cv) + +# extract GATEs estimates +summary(est_cv)$GATE +``` +The table reports the quintile GATE ($K = 5$) estimates for each ML algorithm. We find that the Random Forest produces a statistically significant negative GATE for the lowest quintile group (group 1) under cross-validation. This provides evidence that the Random Forest is able to identify a 20% subgroup whose writing scores are negatively impacted by small class sizes. + +We now conduct the statistical tests of treatment effect heterogeneity and rank consistency to validate these GATE estimates. We pass the output `est_cv` from the `evaluate_itr` function to the `test_itr` function, which conducts the two tests simultaneously and returns the object `test_est_cv`. We then summarize the test statistics and p-values using the `summary` function. The `nsim` argument specifies the number of Monte Carlo simulations used for each test; the default is 1000.
+ ```{r warning=FALSE, message=FALSE} +# conduct nonparametric tests +test_est_cv <- test_itr(est_cv, + nsim = 5000) + +# summarize test statistics and p-values +summary(test_est_cv) +``` +The table reports the test statistics and p-values for each test under each algorithm. None of the ML algorithms rejects the null hypothesis of treatment effect homogeneity under cross-validation, indicating that, taken jointly, the estimated GATEs do not vary across subgroups more than chance would allow. In addition, none of the ML algorithms rejects the rank consistency hypothesis under cross-validation. Thus, there is no strong statistical evidence that these algorithms produce unreliable GATEs.
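+The summary object is a named list of tables (`Consistency_cv` and `Heterogeneity_cv` under cross-validation, per this package's `summary.test_itr` method), so the test results can also be consumed programmatically. A minimal sketch:
+```{r, eval=FALSE}
+# pull the test tables out of the summary object by name
+tests <- summary(test_est_cv)
+tests[["Consistency_cv"]]                            # rank-consistency results
+subset(tests[["Heterogeneity_cv"]], p.value < 0.05)  # algorithms rejecting homogeneity
+```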