Skip to content

Commit f248655

Browse files
committed
Add outcome_type parameter and cross-fitting for transportability
1 parent b004c2e commit f248655

7 files changed

Lines changed: 3068 additions & 363 deletions

File tree

R/cf_mse.R

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
#' - `"dr"`: Doubly robust estimator (default)
1616
#' @param propensity_model Optional fitted propensity score model. If NULL,
1717
#' a logistic regression model is fit using the covariates.
18-
#' @param outcome_model Optional fitted outcome model. If NULL, a logistic
19-
#' regression model is fit using the covariates among treated/untreated.
18+
#' @param outcome_model Optional fitted outcome model. If NULL, a regression
19+
#' model is fit using the covariates among treated/untreated. For binary
20+
#' outcomes, this should be a model for E\[Y|X,A\] (binomial family). For
21+
#' continuous outcomes, this should be a model for E\[L|X,A\] (gaussian family).
22+
#' @param outcome_type Character string specifying the outcome type:
23+
#' - `"auto"`: Auto-detect from data (default)
24+
#' - `"binary"`: Binary outcome (0/1) - uses efficient transformation
25+
#' - `"continuous"`: Continuous outcome - models loss directly
2026
#' @param se_method Method for standard error estimation:
2127
#' - `"bootstrap"`: Bootstrap standard errors (default)
2228
#' - `"influence"`: Influence function-based standard errors
@@ -99,6 +105,7 @@ cf_mse <- function(predictions,
99105
estimator = c("dr", "cl", "ipw", "naive"),
100106
propensity_model = NULL,
101107
outcome_model = NULL,
108+
outcome_type = c("auto", "binary", "continuous"),
102109
se_method = c("bootstrap", "influence", "none"),
103110
n_boot = 500,
104111
conf_level = 0.95,
@@ -112,13 +119,19 @@ cf_mse <- function(predictions,
112119
# Input validation
113120
estimator <- match.arg(estimator)
114121
se_method <- match.arg(se_method)
122+
outcome_type <- match.arg(outcome_type)
115123

116124
# Validate inputs
117125

118126
.validate_inputs(predictions, outcomes, treatment, covariates)
119127

120128
n <- length(outcomes)
121129

130+
# Auto-detect outcome type if not specified
131+
if (outcome_type == "auto") {
132+
outcome_type <- if (all(outcomes %in% c(0, 1))) "binary" else "continuous"
133+
}
134+
122135
# Initialize SE variables
123136
se <- NULL
124137
ci_lower <- NULL
@@ -141,6 +154,7 @@ cf_mse <- function(predictions,
141154
K = n_folds,
142155
propensity_learner = if (use_ml_propensity) propensity_model else NULL,
143156
outcome_learner = if (use_ml_outcome) outcome_model else NULL,
157+
outcome_type = outcome_type,
144158
parallel = parallel,
145159
...
146160
)
@@ -165,7 +179,10 @@ cf_mse <- function(predictions,
165179
covariates = covariates,
166180
treatment_level = treatment_level,
167181
propensity_model = propensity_model,
168-
outcome_model = outcome_model
182+
outcome_model = outcome_model,
183+
estimator = estimator,
184+
outcome_type = outcome_type,
185+
predictions = predictions
169186
)
170187
estimate <- NULL
171188
}
@@ -184,7 +201,8 @@ cf_mse <- function(predictions,
184201
treatment_level = treatment_level,
185202
estimator = estimator,
186203
propensity_model = nuisance$propensity,
187-
outcome_model = nuisance$outcome
204+
outcome_model = nuisance$outcome,
205+
outcome_type = outcome_type
188206
)
189207
}
190208

@@ -262,7 +280,7 @@ cf_mse <- function(predictions,
262280
# Internal function to compute MSE
263281
.compute_mse <- function(predictions, outcomes, treatment, covariates,
264282
treatment_level, estimator, propensity_model,
265-
outcome_model) {
283+
outcome_model, outcome_type = "binary") {
266284

267285
n <- length(outcomes)
268286
loss <- (outcomes - predictions)^2
@@ -288,9 +306,17 @@ cf_mse <- function(predictions,
288306

289307
# Get outcome predictions (conditional loss) for ALL observations
290308
if (!is.null(outcome_model)) {
291-
# For binary outcomes with squared error loss
309+
# For binary outcomes, the outcome model predicts E[Y|X,A] = pY
310+
# and we transform to E[L|X,A] = pY - 2*pred*pY + pred^2
311+
# For continuous outcomes, the model directly predicts E[L|X,A]
292312
pY <- .predict_nuisance(outcome_model, covariates)
293-
h <- pY - 2 * predictions * pY + predictions^2
313+
if (outcome_type == "binary") {
314+
# E[(Y - pred)^2 | X] = E[Y | X] - 2*pred*E[Y | X] + pred^2
315+
# since Y^2 = Y for binary Y
316+
h <- pY - 2 * predictions * pY + predictions^2
317+
} else {
318+
h <- pY
319+
}
294320
}
295321

296322
# Indicator for counterfactual treatment
@@ -301,9 +327,9 @@ cf_mse <- function(predictions,
301327
return(mean(h))
302328

303329
} else if (estimator == "ipw") {
304-
# IPW estimator
330+
# IPW estimator (Horvitz-Thompson style)
305331
weights <- I_a / ps
306-
return(sum(weights * loss) / sum(I_a))
332+
return(mean(weights * loss))
307333

308334
} else if (estimator == "dr") {
309335
# Doubly robust estimator
@@ -316,15 +342,17 @@ cf_mse <- function(predictions,
316342
# Internal function to fit nuisance models
317343
.fit_nuisance_models <- function(treatment, outcomes, covariates,
318344
treatment_level, propensity_model,
319-
outcome_model) {
345+
outcome_model, estimator = "dr",
346+
outcome_type = "binary",
347+
predictions = NULL) {
320348

321349
# Convert covariates to data frame if needed
322350
if (!is.data.frame(covariates)) {
323351
covariates <- as.data.frame(covariates)
324352
}
325353

326-
# Fit propensity model if not provided
327-
if (is.null(propensity_model)) {
354+
# Fit propensity model if not provided (needed for ipw and dr)
355+
if (estimator %in% c("ipw", "dr") && is.null(propensity_model)) {
328356
ps_data <- cbind(A = treatment, covariates)
329357
propensity_model <- glm(A ~ ., data = ps_data, family = binomial())
330358
} else if (is_ml_learner(propensity_model)) {
@@ -337,11 +365,22 @@ cf_mse <- function(predictions,
337365
data = ps_data, family = "binomial")
338366
}
339367

340-
# Fit outcome model if not provided (among those with counterfactual treatment)
341-
if (is.null(outcome_model)) {
368+
# Fit outcome model if not provided (needed for cl and dr only)
369+
# For binary outcomes: model E[Y | X, A=a] and transform to loss later
370+
# For continuous outcomes: model E[L | X, A=a] directly
371+
if (estimator %in% c("cl", "dr") && is.null(outcome_model)) {
342372
subset_idx <- treatment == treatment_level
343-
outcome_data <- cbind(Y = outcomes, covariates)[subset_idx, ]
344-
outcome_model <- glm(Y ~ ., data = outcome_data, family = binomial())
373+
374+
if (outcome_type == "binary") {
375+
# Model E[Y | X, A=a] - the transformation to loss happens in .compute_mse
376+
outcome_data <- cbind(Y = outcomes, covariates)[subset_idx, ]
377+
outcome_model <- glm(Y ~ ., data = outcome_data, family = binomial())
378+
} else {
379+
# Model E[L | X, A=a] directly for continuous outcomes
380+
loss <- (outcomes - predictions)^2
381+
outcome_data <- cbind(L = loss, covariates)[subset_idx, ]
382+
outcome_model <- glm(L ~ ., data = outcome_data, family = gaussian())
383+
}
345384
# Store the full data for prediction
346385
attr(outcome_model, "full_data") <- cbind(Y = outcomes, covariates)
347386
} else if (is_ml_learner(outcome_model)) {
@@ -350,9 +389,17 @@ cf_mse <- function(predictions,
350389
"Using cross_fit=TRUE is recommended for ML learners.",
351390
call. = FALSE)
352391
subset_idx <- treatment == treatment_level
353-
outcome_data <- cbind(Y = outcomes, covariates)[subset_idx, ]
354-
outcome_model <- .fit_ml_learner(outcome_model, Y ~ .,
355-
data = outcome_data, family = "binomial")
392+
393+
if (outcome_type == "binary") {
394+
outcome_data <- cbind(Y = outcomes, covariates)[subset_idx, ]
395+
outcome_model <- .fit_ml_learner(outcome_model, Y ~ .,
396+
data = outcome_data, family = "binomial")
397+
} else {
398+
loss <- (outcomes - predictions)^2
399+
outcome_data <- cbind(L = loss, covariates)[subset_idx, ]
400+
outcome_model <- .fit_ml_learner(outcome_model, L ~ .,
401+
data = outcome_data, family = "gaussian")
402+
}
356403
}
357404

358405
list(propensity = propensity_model, outcome = outcome_model)

0 commit comments

Comments
 (0)