From 9a9d5f623b4d0c6fcd2b9e56697eb5975516cb40 Mon Sep 17 00:00:00 2001 From: Lief Date: Sat, 23 Oct 2021 14:29:36 -0700 Subject: [PATCH] Stops augsynth from segfaulting with na values or unbalanced panels. Fix #56, fix #53 --- R/augsynth.R | 122 ++++++++++++++++++++--------------- tests/testthat/test_format.R | 23 +++++-- 2 files changed, 89 insertions(+), 56 deletions(-) diff --git a/R/augsynth.R b/R/augsynth.R index 00c7f96..b4cfc28 100644 --- a/R/augsynth.R +++ b/R/augsynth.R @@ -4,7 +4,7 @@ #' Fit Augmented SCM -#' +#' #' @param form outcome ~ treatment | auxillary covariates #' @param unit Name of unit column #' @param time Name of time column @@ -14,11 +14,11 @@ #' ridge=Ridge regression (allows for standard errors), #' none=No outcome model, #' en=Elastic Net, RF=Random Forest, GSYN=gSynth, -#' mcp=MCPanel, +#' mcp=MCPanel, #' cits=Comparitive Interuppted Time Series #' causalimpact=Bayesian structural time series with CausalImpact #' @param scm Whether the SCM weighting function is used -#' @param fixedeff Whether to include a unit fixed effect, default F +#' @param fixedeff Whether to include a unit fixed effect, default F #' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted #' @param ... optional arguments for outcome model #' @@ -43,15 +43,33 @@ single_augsynth <- function(form, unit, time, t_int, data, unit <- enquo(unit) time <- enquo(time) + ## validate input data + + # Check for NA values + if( sum(is.na(data %>% + select(!!unit, !!time, any_of(as.character(form))) + ) + ) > 0 ) { + stop("Missing values detected.") + } + + # Check whether there are omitted rows + full_data <- data %>% + tidyr::expand({{unit}}, {{time}}) + + if( nrow(data) != nrow(full_data) ) { + stop("There are missing rows in the input data set. Panel must be balanced.") + } + ## format data outcome <- terms(formula(form, rhs=1))[[2]] trt <- terms(formula(form, rhs=1))[[3]] wide <- format_data(outcome, trt, unit, time, t_int, data) synth_data <- do.call(format_synth, wide) - + treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit) - control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% + control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit) ## add covariates if(length(form)[2] == 2) { @@ -59,17 +77,17 @@ single_augsynth <- function(form, unit, time, t_int, data, } else { Z <- NULL } - + # fit augmented SCM - augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, + augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, scm, fixedeff, ...) - + # add some extra data augsynth$data$time <- data %>% distinct(!!time) %>% arrange(!!time) %>% pull(!!time) augsynth$call <- call_name - augsynth$t_int <- t_int - + augsynth$t_int <- t_int + augsynth$weights <- matrix(augsynth$weights) rownames(augsynth$weights) <- control_units @@ -86,9 +104,9 @@ single_augsynth <- function(form, unit, time, t_int, data, #' @param fixedeff Whether to de-mean synth #' @param V V matrix for Synth, default NULL #' @param ... Extra args for outcome model -#' +#' #' @noRd -#' +#' fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, scm, fixedeff, V = NULL, ...) { @@ -119,7 +137,7 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, } else if(progfunc == "none") { ## Just SCM augsynth <- do.call(fit_ridgeaug_formatted, - c(list(wide_data = fit_wide, + c(list(wide_data = fit_wide, synth_data = fit_synth_data, Z = Z, ridge = F, scm = T, V = V, ...))) } else { @@ -127,15 +145,15 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp", "cits", "causalimpact", "seq2seq") if (progfunc %in% progfuncs) { - augsynth <- fit_augsyn(fit_wide, fit_synth_data, + augsynth <- fit_augsyn(fit_wide, fit_synth_data, progfunc, scm, ...) } else { stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'") } - + } - augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), + augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), augsynth$mhat) augsynth$data <- wide augsynth$data$Z <- Z @@ -169,13 +187,13 @@ predict.augsynth <- function(object, att = F, ...) { # att <- F # } augsynth <- object - + X <- augsynth$data$X y <- augsynth$data$y comb <- cbind(X, y) trt <- augsynth$data$trt mhat <- augsynth$mhat - + m1 <- colMeans(mhat[trt==1,,drop=F]) resid <- (comb[trt==0,,drop=F] - mhat[trt==0,drop=F]) @@ -198,7 +216,7 @@ predict.augsynth <- function(object, att = F, ...) { #' @export print.augsynth <- function(x, ...) { augsynth <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="") @@ -214,7 +232,7 @@ print.augsynth <- function(x, ...) { #' Plot function for augsynth #' @importFrom graphics plot -#' +#' #' @param x Augsynth object to be plotted #' @param inf Boolean, whether to get confidence intervals around the point estimates #' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects @@ -228,22 +246,22 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) { # } augsynth <- x - + if (cv == T) { errors = data.frame(lambdas = augsynth$lambdas, errors = augsynth$lambda_errors, errors_se = augsynth$lambda_errors_se) p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) + - ggplot2::geom_point(size = 2) + + ggplot2::geom_point(size = 2) + ggplot2::geom_errorbar( ggplot2::aes(ymin = errors, ymax = errors + errors_se), - width=0.2, size = 0.5) + width=0.2, size = 0.5) p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda), - x = expression(lambda), y = "Cross Validation MSE", + x = expression(lambda), y = "Cross Validation MSE", parse = TRUE) p <- p + ggplot2::scale_x_log10() - + # find minimum and min + 1se lambda to plot min_lambda <- choose_lambda(augsynth$lambdas, augsynth$lambda_errors, @@ -257,7 +275,7 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) { min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda) p <- p + ggplot2::geom_point( - ggplot2::aes(x = min_lambda, + ggplot2::aes(x = min_lambda, y = augsynth$lambda_errors[min_lambda_index]), color = "gold") p + ggplot2::geom_point( @@ -299,8 +317,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) { # } else { # inf_type <- "conformal" # } - - + + summ <- list() t0 <- ncol(augsynth$data$X) @@ -382,8 +400,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) { } else { summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w } - - + + summ$inf_type <- if(inf) inf_type else "None" class(summ) <- "summary.augsynth" return(summ) @@ -395,7 +413,7 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) { #' @export print.summary.augsynth <- function(x, ...) { summ <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="") @@ -405,7 +423,7 @@ print.summary.augsynth <- function(x, ...) { att_est <- summ$att$Estimate t_total <- length(att_est) t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow() - + att_pre <- att_est[1:(t_int-1)] att_post <- att_est[t_int:t_total] @@ -420,14 +438,14 @@ print.summary.augsynth <- function(x, ...) { se_avg <- summ$average_att$Std.Error out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ", - format(round(att_post,3), nsmall=3), + format(round(att_post,3), nsmall=3), " (", format(round(se_avg,3)), ")\n") inf_type <- "Jackknife over units" } else if(summ$inf_type == "conformal") { p_val <- summ$average_att$p_val out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ", - format(round(att_post,3), nsmall=3), + format(round(att_post,3), nsmall=3), " (", format(round(p_val,3)), ")\n") inf_type <- "Conformal inference" @@ -442,7 +460,7 @@ print.summary.augsynth <- function(x, ...) { } - out_msg <- paste(out_msg, + out_msg <- paste(out_msg, "L2 Imbalance: ", format(round(summ$l2_imbalance,3), nsmall=3), "\n", "Percent improvement from uniform weights: ", @@ -452,16 +470,16 @@ print.summary.augsynth <- function(x, ...) { out_msg <- paste(out_msg, "Covariate L2 Imbalance: ", - format(round(summ$covariate_l2_imbalance,3), + format(round(summ$covariate_l2_imbalance,3), nsmall=3), "\n", "Percent improvement from uniform weights: ", - format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), + format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), "%\n\n", sep="") } - out_msg <- paste(out_msg, + out_msg <- paste(out_msg, "Avg Estimated Bias: ", format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n", "Inference type: ", @@ -471,30 +489,30 @@ print.summary.augsynth <- function(x, ...) { cat(out_msg) if(summ$inf_type == "jackknife") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, Std.Error) } else if(summ$inf_type == "conformal") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, lower_bound, upper_bound, p_val) - names(out_att) <- c("Time", "Estimate", + names(out_att) <- c("Time", "Estimate", paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), paste0((1 - summ$alpha) * 100, "% CI Upper Bound"), paste0("p Value")) } else if(summ$inf_type == "jackknife+") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, lower_bound, upper_bound) - names(out_att) <- c("Time", "Estimate", + names(out_att) <- c("Time", "Estimate", paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), paste0((1 - summ$alpha) * 100, "% CI Upper Bound")) } else { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate) } out_att %>% mutate_at(vars(-Time), ~ round(., 3)) %>% print(row.names = F) - + } #' Plot function for summary function for augsynth @@ -509,7 +527,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) { # } else { # inf <- T # } - + p <- summ$att %>% ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) if(inf) { @@ -526,7 +544,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) { } p + ggplot2::geom_line() + ggplot2::geom_vline(xintercept=summ$t_int, lty=2) + - ggplot2::geom_hline(yintercept=0, lty=2) + + ggplot2::geom_hline(yintercept=0, lty=2) + ggplot2::theme_bw() } @@ -534,7 +552,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) { #' augsynth -#' +#' #' @description A package implementing the Augmented Synthetic Controls Method #' @docType package #' @name augsynth-package @@ -545,9 +563,9 @@ plot.summary.augsynth <- function(x, inf = T, ...) { #' @import tidyr #' @importFrom stats terms #' @importFrom stats formula -#' @importFrom stats update -#' @importFrom stats delete.response -#' @importFrom stats model.matrix -#' @importFrom stats model.frame +#' @importFrom stats update +#' @importFrom stats delete.response +#' @importFrom stats model.matrix +#' @importFrom stats model.frame #' @importFrom stats na.omit NULL diff --git a/tests/testthat/test_format.R b/tests/testthat/test_format.R index 6f79d55..8154ecf 100644 --- a/tests/testthat/test_format.R +++ b/tests/testthat/test_format.R @@ -6,9 +6,24 @@ basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, regionno != 17 ~0, regionno == 17 ~ 1)) %>% filter(regionno != 1) - + +test_that("augsynth exits gracefully with missing data or unbalanced panels", { + + kansas_drop <- kansas[-15,] # remove an arbitrary row + + expect_error(augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas_drop, + progfunc = "None", scm = T)) + + kansas_na <- kansas # set a missing value + kansas_na[12, "lngdpcapita"] = NA + + expect_error(augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas_na, + progfunc = "None", scm = T)) +} +) + test_that("format_data creates matrices with the right dimensions", { - + dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque) test_dim <- function(obj, d) { @@ -23,7 +38,7 @@ test_that("format_data creates matrices with the right dimensions", { test_that("format_synth creates matrices with the right dimensions", { - + dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque) syn_dat <- format_synth(dat$X, dat$trt, dat$y) test_dim <- function(obj, d) { @@ -103,4 +118,4 @@ test_that("multisynth throws errors when there aren't enough pre-treatment times expect_equal(true_trt, sort(dat$trt[is.finite(dat$trt)])) - }) \ No newline at end of file + })