1515# ' - `"dr"`: Doubly robust estimator (default)
1616# ' @param propensity_model Optional fitted propensity score model. If NULL,
1717# ' a logistic regression model is fit using the covariates.
18- # ' @param outcome_model Optional fitted outcome model. If NULL, a logistic
19- # ' regression model is fit using the covariates among treated/untreated.
18+ # ' @param outcome_model Optional fitted outcome model. If NULL, a regression
19+ # ' model is fit using the covariates among treated/untreated. For binary
20+ # ' outcomes, this should be a model for E\[Y|X,A\] (binomial family). For
21+ # ' continuous outcomes, this should be a model for E\[L|X,A\] (gaussian family).
22+ # ' @param outcome_type Character string specifying the outcome type:
23+ # ' - `"auto"`: Auto-detect from data (default)
24+ # ' - `"binary"`: Binary outcome (0/1) - uses efficient transformation
25+ # ' - `"continuous"`: Continuous outcome - models loss directly
2026# ' @param se_method Method for standard error estimation:
2127# ' - `"bootstrap"`: Bootstrap standard errors (default)
2228# ' - `"influence"`: Influence function-based standard errors
@@ -99,6 +105,7 @@ cf_mse <- function(predictions,
99105 estimator = c(" dr" , " cl" , " ipw" , " naive" ),
100106 propensity_model = NULL ,
101107 outcome_model = NULL ,
108+ outcome_type = c(" auto" , " binary" , " continuous" ),
102109 se_method = c(" bootstrap" , " influence" , " none" ),
103110 n_boot = 500 ,
104111 conf_level = 0.95 ,
@@ -112,13 +119,19 @@ cf_mse <- function(predictions,
112119 # Input validation
113120 estimator <- match.arg(estimator )
114121 se_method <- match.arg(se_method )
122+ outcome_type <- match.arg(outcome_type )
115123
116124 # Validate inputs
117125
118126 .validate_inputs(predictions , outcomes , treatment , covariates )
119127
120128 n <- length(outcomes )
121129
130+ # Auto-detect outcome type if not specified
131+ if (outcome_type == " auto" ) {
132+ outcome_type <- if (all(outcomes %in% c(0 , 1 ))) " binary" else " continuous"
133+ }
134+
122135 # Initialize SE variables
123136 se <- NULL
124137 ci_lower <- NULL
@@ -141,6 +154,7 @@ cf_mse <- function(predictions,
141154 K = n_folds ,
142155 propensity_learner = if (use_ml_propensity ) propensity_model else NULL ,
143156 outcome_learner = if (use_ml_outcome ) outcome_model else NULL ,
157+ outcome_type = outcome_type ,
144158 parallel = parallel ,
145159 ...
146160 )
@@ -165,7 +179,10 @@ cf_mse <- function(predictions,
165179 covariates = covariates ,
166180 treatment_level = treatment_level ,
167181 propensity_model = propensity_model ,
168- outcome_model = outcome_model
182+ outcome_model = outcome_model ,
183+ estimator = estimator ,
184+ outcome_type = outcome_type ,
185+ predictions = predictions
169186 )
170187 estimate <- NULL
171188 }
@@ -184,7 +201,8 @@ cf_mse <- function(predictions,
184201 treatment_level = treatment_level ,
185202 estimator = estimator ,
186203 propensity_model = nuisance $ propensity ,
187- outcome_model = nuisance $ outcome
204+ outcome_model = nuisance $ outcome ,
205+ outcome_type = outcome_type
188206 )
189207 }
190208
@@ -262,7 +280,7 @@ cf_mse <- function(predictions,
262280# Internal function to compute MSE
263281.compute_mse <- function (predictions , outcomes , treatment , covariates ,
264282 treatment_level , estimator , propensity_model ,
265- outcome_model ) {
283+ outcome_model , outcome_type = " binary " ) {
266284
267285 n <- length(outcomes )
268286 loss <- (outcomes - predictions )^ 2
@@ -288,9 +306,17 @@ cf_mse <- function(predictions,
288306
289307 # Get outcome predictions (conditional loss) for ALL observations
290308 if (! is.null(outcome_model )) {
291- # For binary outcomes with squared error loss
309+ # For binary outcomes, the outcome model predicts E[Y|X,A] = pY
310+ # and we transform to E[L|X,A] = pY - 2*pred*pY + pred^2
311+ # For continuous outcomes, the model directly predicts E[L|X,A]
292312 pY <- .predict_nuisance(outcome_model , covariates )
293- h <- pY - 2 * predictions * pY + predictions ^ 2
313+ if (outcome_type == " binary" ) {
314+ # E[(Y - pred)^2 | X] = E[Y | X] - 2*pred*E[Y | X] + pred^2
315+ # since Y^2 = Y for binary Y
316+ h <- pY - 2 * predictions * pY + predictions ^ 2
317+ } else {
318+ h <- pY
319+ }
294320 }
295321
296322 # Indicator for counterfactual treatment
@@ -301,9 +327,9 @@ cf_mse <- function(predictions,
301327 return (mean(h ))
302328
303329 } else if (estimator == " ipw" ) {
304- # IPW estimator
330+ # IPW estimator (Horvitz-Thompson style)
305331 weights <- I_a / ps
306- return (sum (weights * loss ) / sum( I_a ))
332+ return (mean (weights * loss ))
307333
308334 } else if (estimator == " dr" ) {
309335 # Doubly robust estimator
@@ -316,15 +342,17 @@ cf_mse <- function(predictions,
316342# Internal function to fit nuisance models
317343.fit_nuisance_models <- function (treatment , outcomes , covariates ,
318344 treatment_level , propensity_model ,
319- outcome_model ) {
345+ outcome_model , estimator = " dr" ,
346+ outcome_type = " binary" ,
347+ predictions = NULL ) {
320348
321349 # Convert covariates to data frame if needed
322350 if (! is.data.frame(covariates )) {
323351 covariates <- as.data.frame(covariates )
324352 }
325353
326- # Fit propensity model if not provided
327- if (is.null(propensity_model )) {
354+ # Fit propensity model if not provided (needed for ipw and dr)
355+ if (estimator %in% c( " ipw " , " dr " ) && is.null(propensity_model )) {
328356 ps_data <- cbind(A = treatment , covariates )
329357 propensity_model <- glm(A ~ . , data = ps_data , family = binomial())
330358 } else if (is_ml_learner(propensity_model )) {
@@ -337,11 +365,22 @@ cf_mse <- function(predictions,
337365 data = ps_data , family = " binomial" )
338366 }
339367
340- # Fit outcome model if not provided (among those with counterfactual treatment)
341- if (is.null(outcome_model )) {
368+ # Fit outcome model if not provided (needed for cl and dr only)
369+ # For binary outcomes: model E[Y | X, A=a] and transform to loss later
370+ # For continuous outcomes: model E[L | X, A=a] directly
371+ if (estimator %in% c(" cl" , " dr" ) && is.null(outcome_model )) {
342372 subset_idx <- treatment == treatment_level
343- outcome_data <- cbind(Y = outcomes , covariates )[subset_idx , ]
344- outcome_model <- glm(Y ~ . , data = outcome_data , family = binomial())
373+
374+ if (outcome_type == " binary" ) {
375+ # Model E[Y | X, A=a] - the transformation to loss happens in .compute_mse
376+ outcome_data <- cbind(Y = outcomes , covariates )[subset_idx , ]
377+ outcome_model <- glm(Y ~ . , data = outcome_data , family = binomial())
378+ } else {
379+ # Model E[L | X, A=a] directly for continuous outcomes
380+ loss <- (outcomes - predictions )^ 2
381+ outcome_data <- cbind(L = loss , covariates )[subset_idx , ]
382+ outcome_model <- glm(L ~ . , data = outcome_data , family = gaussian())
383+ }
345384 # Store the full data for prediction
346385 attr(outcome_model , " full_data" ) <- cbind(Y = outcomes , covariates )
347386 } else if (is_ml_learner(outcome_model )) {
@@ -350,9 +389,17 @@ cf_mse <- function(predictions,
350389 " Using cross_fit=TRUE is recommended for ML learners." ,
351390 call. = FALSE )
352391 subset_idx <- treatment == treatment_level
353- outcome_data <- cbind(Y = outcomes , covariates )[subset_idx , ]
354- outcome_model <- .fit_ml_learner(outcome_model , Y ~ . ,
355- data = outcome_data , family = " binomial" )
392+
393+ if (outcome_type == " binary" ) {
394+ outcome_data <- cbind(Y = outcomes , covariates )[subset_idx , ]
395+ outcome_model <- .fit_ml_learner(outcome_model , Y ~ . ,
396+ data = outcome_data , family = " binomial" )
397+ } else {
398+ loss <- (outcomes - predictions )^ 2
399+ outcome_data <- cbind(L = loss , covariates )[subset_idx , ]
400+ outcome_model <- .fit_ml_learner(outcome_model , L ~ . ,
401+ data = outcome_data , family = " gaussian" )
402+ }
356403 }
357404
358405 list (propensity = propensity_model , outcome = outcome_model )
0 commit comments