Skip to content

Commit 237278f

Browse files
committed
Fix AUC cross-fitting: use GRF for better calibration
Key changes: - Fix .predict_grf() to work with formula containing '.' - Update propensity score truncation bounds to [0.025, 0.975] - Truncate outcome probabilities (q_hat) for stability - Add warning when >10% of propensity scores at truncation bounds - Update vignette to recommend GRF over ranger for AUC estimation - Document that GRF's 'honest' estimation produces well-calibrated probabilities needed for the doubly robust AUC estimator The DR AUC estimator is sensitive to poorly calibrated probability estimates due to its pairwise/outer-product structure. Standard random forests (ranger) can produce extreme predictions that destabilize the estimator. GRF's probability_forest with honesty produces estimates nearly identical to GLM in our tests.
1 parent 3138582 commit 237278f

3 files changed

Lines changed: 45 additions & 11 deletions

File tree

R/ml_learner.R

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,17 @@ print.ml_learner <- function(x, ...) {
330330
}
331331

332332
.predict_grf <- function(fit, newdata) {
333-
formula <- attr(fit, "formula")
334-
x <- model.matrix(formula, data = newdata)[, -1, drop = FALSE]
333+
# GRF stores the training data columns - use those names
334+
# For probability_forest and regression_forest, the X matrix column names are stored
335+
train_cols <- colnames(fit$X.orig)
336+
337+
if (is.null(train_cols)) {
338+
# Fallback: use all columns from newdata
339+
x <- as.matrix(newdata)
340+
} else {
341+
# Use the same columns as training
342+
x <- as.matrix(newdata[, train_cols, drop = FALSE])
343+
}
335344

336345
pred <- predict(fit, newdata = x)
337346

R/variance.R

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,10 @@ NULL
553553
if (treatment_level == 0) {
554554
ps_pred <- 1 - ps_pred
555555
}
556-
ps_cf[val_idx] <- pmax(pmin(ps_pred, 0.99), 0.01)
556+
557+
# Truncate propensity scores for stability
558+
# Use 0.025-0.975 bounds to avoid extreme weights while allowing flexibility
559+
ps_cf[val_idx] <- pmax(pmin(ps_pred, 0.975), 0.025)
557560

558561
# Fit outcome model on training fold (among treated)
559562
subset_train <- train_idx[treatment[train_idx] == treatment_level]
@@ -577,7 +580,8 @@ NULL
577580
}
578581

579582
# Store outcome probability (q_hat) for AUC
580-
q_cf[val_idx] <- pY
583+
# Truncate to avoid extreme values that cause instability in DR estimator
584+
q_cf[val_idx] <- pmax(pmin(pY, 0.99), 0.01)
581585

582586
# Compute conditional loss: E[(Y - pred)^2 | X, A=a] = p(1-p) + (p - pred)^2
583587
# For binary Y: E[Y^2] = p, so E[(Y - pred)^2] = p - 2*p*pred + pred^2
@@ -700,8 +704,18 @@ NULL
700704
# Treatment indicator
701705
I_a <- as.numeric(treatment == treatment_level)
702706

703-
# Truncate propensity scores for stability
704-
ps <- pmax(pmin(ps, 0.99), 0.01)
707+
# Note: propensity scores are already truncated in .cross_fit_nuisance()
708+
# Additional truncation here is defensive
709+
710+
# Check for extreme propensity scores and warn
711+
ps_extreme <- sum(ps <= 0.025 | ps >= 0.975)
712+
if (ps_extreme > 0.1 * n) {
713+
warning(sprintf(
714+
"%.0f%% of propensity scores are at truncation bounds. ",
715+
100 * ps_extreme / n
716+
), "Consider using a simpler propensity model or more regularization.",
717+
call. = FALSE)
718+
}
705719

706720
# Concordance indicator matrix (f_i > f_j)
707721
ind_f <- outer(predictions, predictions, ">")

vignettes/ml-integration.Rmd

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,23 @@ print(result_mse)
293293

294294
### AUC with ML Learners
295295

296-
```{r auc-final-example, eval=requireNamespace("ranger", quietly = TRUE)}
297-
# AUC estimation with ML nuisance models
296+
For AUC estimation, we recommend using **GRF (Generalized Random Forests)**
297+
rather than standard random forests. GRF's "honesty" property produces
298+
well-calibrated probability estimates, which is critical for the doubly robust
299+
AUC estimator. Standard random forests can produce extreme probability
300+
predictions that destabilize the estimator.
301+
302+
```{r auc-final-example, eval=requireNamespace("grf", quietly = TRUE)}
303+
# AUC estimation with GRF nuisance models (recommended)
298304
result_auc <- cf_auc(
299305
predictions = pred,
300306
outcomes = y,
301307
treatment = a,
302308
covariates = df,
303309
treatment_level = 0,
304310
estimator = "dr",
305-
propensity_model = ml_learner("ranger", num.trees = 100),
306-
outcome_model = ml_learner("ranger", num.trees = 100),
311+
propensity_model = ml_learner("grf", num.trees = 500),
312+
outcome_model = ml_learner("grf", num.trees = 500),
307313
cross_fit = TRUE,
308314
n_folds = 5,
309315
se_method = "influence"
@@ -328,7 +334,12 @@ print(result_auc)
328334
exploration, use fewer trees/rounds, then increase for final analysis.
329335

330336
5. **Check for extreme propensity scores**: ML methods can produce very extreme
331-
propensity scores. The package truncates these at [0.01, 0.99] by default.
337+
propensity scores. The package truncates these at [0.025, 0.975] by default.
338+
339+
6. **Use GRF for AUC estimation**: The doubly robust AUC estimator is sensitive
340+
to poorly calibrated probability estimates. GRF's "honest" estimation
341+
produces well-calibrated probabilities, while standard random forests
342+
(ranger) can produce extreme predictions that destabilize the estimator.
332343

333344
## References
334345

0 commit comments

Comments
 (0)