Skip to content

Commit 3138582

Browse files
committed
Add cross-fitting support for cf_auc() with ML learners
- Add .compute_auc_crossfit() function in variance.R for DR AUC with cross-fitting
- Update .cross_fit_nuisance() to return q_hat (outcome probability) for AUC
- Add cross_fit and n_folds parameters to cf_auc()
- Update cf_auc documentation with cross-fitting reference
- Add tests for cf_auc with cross_fit and ML learners
- Update ml-integration vignette with cf_auc() example
- Update NEWS.md to document cf_auc cross-fitting support

References:
- Li et al. (2022) Biometrics for AUC transportability
- Chernozhukov et al. (2018) for double/debiased ML
1 parent 38ca4b3 commit 3138582

7 files changed

Lines changed: 373 additions & 34 deletions

File tree

NEWS.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ automatic cross-fitting for valid inference.
1111
* Automatic cross-fitting when `ml_learner` specs are detected
1212
* Seamlessly integrates with existing `propensity_model`/`outcome_model` arguments
1313

14+
### Cross-Fitting Support
15+
* `cf_mse()` - Full support for ML learners with cross-fitting
16+
* `cf_auc()` - Full support for ML learners with cross-fitting (DR estimator)
17+
1418
### Supported Learners
1519
* **ranger** - Fast random forest implementation
1620
* **xgboost** - Gradient boosting (XGBoost)
@@ -21,12 +25,21 @@ automatic cross-fitting for valid inference.
2125

2226
### Usage Example
2327
```r
28+
# MSE with ML learners
2429
cf_mse(
2530
predictions = pred, outcomes = y, treatment = a, covariates = df,
2631
propensity_model = ml_learner("ranger", num.trees = 500),
2732
outcome_model = ml_learner("xgboost", nrounds = 100),
2833
cross_fit = TRUE
2934
)
35+
36+
# AUC with ML learners
37+
cf_auc(
38+
predictions = pred, outcomes = y, treatment = a, covariates = df,
39+
propensity_model = ml_learner("ranger", num.trees = 500),
40+
outcome_model = ml_learner("ranger", num.trees = 500),
41+
cross_fit = TRUE
42+
)
3043
```
3144

3245
### Documentation
@@ -38,6 +51,10 @@ cf_mse(
3851
Chernozhukov, V., et al. (2018). Double/debiased machine learning for treatment
3952
and structural parameters. *The Econometrics Journal*, 21(1), C1-C68.
4053

54+
Li, B., Gatsonis, C., Dahabreh, I. J., & Steingrimsson, J. A. (2022).
55+
Estimating the area under the ROC curve when transporting a prediction
56+
model to a target population. *Biometrics*, 79(3), 2343-2356.
57+
4158
---
4259

4360
# cfperformance 0.2.0

R/cf_auc.R

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#' of treatment.
3232
#'
3333
#' **Doubly Robust (DR) Estimator**: Combines OM and IPW for double robustness.
34+
#' When `cross_fit = TRUE`, uses cross-fitting for valid inference with flexible
35+
#' ML methods (see [ml_learner()]).
3436
#'
3537
#' @references
3638
#' Boyer, C. B., Dahabreh, I. J., & Steingrimsson, J. A. (2025).
@@ -39,9 +41,10 @@
3941
#'
4042
#' Li, B., Gatsonis, C., Dahabreh, I. J., & Steingrimsson, J. A. (2022).
4143
#' "Estimating the area under the ROC curve when transporting a prediction
42-
#' model to a target population." *Biometrics*.
44+
#' model to a target population." *Biometrics*, 79(3), 2343-2356.
45+
#' \doi{10.1111/biom.13796}
4346
#'
44-
#' @seealso [cf_mse()], [cf_calibration()]
47+
#' @seealso [cf_mse()], [cf_calibration()], [ml_learner()]
4548
#'
4649
#' @export
4750
#'
@@ -76,6 +79,8 @@ cf_auc <- function(predictions,
7679
se_method = c("bootstrap", "influence", "none"),
7780
n_boot = 500,
7881
conf_level = 0.95,
82+
cross_fit = FALSE,
83+
n_folds = 5,
7984
parallel = FALSE,
8085
ncores = NULL,
8186
...) {
@@ -92,40 +97,79 @@ cf_auc <- function(predictions,
9297

9398
n <- length(outcomes)
9499

100+
# Detect if ml_learners are provided
101+
use_ml_propensity <- is_ml_learner(propensity_model)
102+
use_ml_outcome <- is_ml_learner(outcome_model)
103+
104+
# Initialize SE variables
105+
se <- NULL
106+
ci_lower <- NULL
107+
ci_upper <- NULL
108+
95109
# Fit nuisance models if needed
96110
if (estimator != "naive") {
97-
nuisance <- .fit_nuisance_models(
98-
treatment = treatment,
111+
if (cross_fit && estimator == "dr") {
112+
# Use cross-fitting for DR estimator
113+
cf_result <- .compute_auc_crossfit(
114+
predictions = predictions,
115+
outcomes = outcomes,
116+
treatment = treatment,
117+
covariates = covariates,
118+
treatment_level = treatment_level,
119+
K = n_folds,
120+
propensity_learner = if (use_ml_propensity) propensity_model else NULL,
121+
outcome_learner = if (use_ml_outcome) outcome_model else NULL,
122+
parallel = parallel,
123+
...
124+
)
125+
estimate <- cf_result$estimate
126+
nuisance <- list(propensity = NULL, outcome = NULL,
127+
cross_fitted = TRUE,
128+
ps = cf_result$ps,
129+
q = cf_result$q,
130+
folds = cf_result$folds)
131+
132+
# SE from cross-fitting
133+
if (se_method == "influence") {
134+
se <- cf_result$se
135+
z <- qnorm(1 - (1 - conf_level) / 2)
136+
ci_lower <- estimate - z * se
137+
ci_upper <- estimate + z * se
138+
}
139+
} else {
140+
nuisance <- .fit_nuisance_models(
141+
treatment = treatment,
142+
outcomes = outcomes,
143+
covariates = covariates,
144+
treatment_level = treatment_level,
145+
propensity_model = propensity_model,
146+
outcome_model = outcome_model
147+
)
148+
estimate <- NULL
149+
}
150+
} else {
151+
nuisance <- list(propensity = NULL, outcome = NULL)
152+
estimate <- NULL
153+
}
154+
155+
# Compute point estimate (if not already computed via cross-fitting)
156+
if (is.null(estimate)) {
157+
estimate <- .compute_auc(
158+
predictions = predictions,
99159
outcomes = outcomes,
160+
treatment = treatment,
100161
covariates = covariates,
101162
treatment_level = treatment_level,
102-
propensity_model = propensity_model,
103-
outcome_model = outcome_model
163+
estimator = estimator,
164+
propensity_model = nuisance$propensity,
165+
outcome_model = nuisance$outcome
104166
)
105-
} else {
106-
nuisance <- list(propensity = NULL, outcome = NULL)
107167
}
108168

109-
# Compute point estimate
110-
estimate <- .compute_auc(
111-
predictions = predictions,
112-
outcomes = outcomes,
113-
treatment = treatment,
114-
covariates = covariates,
115-
treatment_level = treatment_level,
116-
estimator = estimator,
117-
propensity_model = nuisance$propensity,
118-
outcome_model = nuisance$outcome
119-
)
120-
121169
# Naive estimate
122170
naive_estimate <- .compute_auc_naive(predictions, outcomes)
123171

124-
# Standard errors
125-
se <- NULL
126-
ci_lower <- NULL
127-
ci_upper <- NULL
128-
172+
# Standard errors (if not already computed via cross-fitting)
129173
if (se_method == "bootstrap") {
130174
boot_result <- .bootstrap_auc(
131175
predictions = predictions,
@@ -142,7 +186,7 @@ cf_auc <- function(predictions,
142186
se <- boot_result$se
143187
ci_lower <- boot_result$ci_lower
144188
ci_upper <- boot_result$ci_upper
145-
} else if (se_method == "influence") {
189+
} else if (se_method == "influence" && !(cross_fit && estimator == "dr")) {
146190
se <- .influence_se_auc(
147191
predictions = predictions,
148192
outcomes = outcomes,

R/variance.R

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,8 @@ NULL
523523

524524
# Initialize output vectors
525525
ps_cf <- numeric(n) # Cross-fitted propensity scores
526-
h_cf <- numeric(n) # Cross-fitted conditional loss predictions
526+
q_cf <- numeric(n) # Cross-fitted outcome probabilities (for AUC)
527+
h_cf <- numeric(n) # Cross-fitted conditional loss predictions (for MSE)
527528

528529
for (k in 1:K) {
529530
# Training and validation indices
@@ -575,13 +576,17 @@ NULL
575576
type = "response")
576577
}
577578

579+
# Store outcome probability (q_hat) for AUC
580+
q_cf[val_idx] <- pY
581+
578582
# Compute conditional loss: E[(Y - pred)^2 | X, A=a] = p(1-p) + (p - pred)^2
579583
# For binary Y: E[Y^2] = p, so E[(Y - pred)^2] = p - 2*p*pred + pred^2
580584
h_cf[val_idx] <- pY - 2 * predictions[val_idx] * pY + predictions[val_idx]^2
581585
}
582586

583587
list(
584588
ps = ps_cf,
589+
q = q_cf,
585590
h = h_cf,
586591
folds = folds
587592
)
@@ -644,3 +649,115 @@ NULL
644649
folds = cf_nuisance$folds
645650
)
646651
}
652+
653+
654+
#' Compute DR AUC with Cross-Fitting
#'
#' Computes the doubly robust (DR) AUC estimator using cross-fitted nuisance
#' functions (propensity scores and outcome probabilities), combining an IPW
#' component, an outcome-model component, and a DR correction term over all
#' case/control pairs.
#'
#' @inheritParams cf_auc
#' @param K Number of folds for cross-fitting.
#' @param propensity_learner Optional ml_learner for propensity model.
#' @param outcome_learner Optional ml_learner for outcome model.
#' @param parallel Logical for parallel processing (not yet implemented).
#' @param ... Additional arguments forwarded to `.cross_fit_nuisance()`.
#'
#' @return A list with `estimate` (DR AUC), `se` (approximate standard error),
#'   `ps` (truncated cross-fitted propensity scores), `q` (cross-fitted
#'   outcome probabilities), and `folds` (fold assignments).
#'
#' @references
#' Li, B., Gatsonis, C., Dahabreh, I. J., & Steingrimsson, J. A. (2022).
#' "Estimating the area under the ROC curve when transporting a prediction
#' model to a target population." *Biometrics*, 79(3), 2343-2356.
#' \doi{10.1111/biom.13796}
#'
#' @keywords internal
.compute_auc_crossfit <- function(predictions, outcomes, treatment, covariates,
                                  treatment_level, K = 5,
                                  propensity_learner = NULL,
                                  outcome_learner = NULL,
                                  parallel = FALSE,
                                  ...) {

  n <- length(outcomes)

  # Cross-fitted nuisance functions: propensity scores and outcome
  # probabilities q_hat = P(Y = 1 | X, A = treatment_level)
  cf_nuisance <- .cross_fit_nuisance(
    treatment = treatment,
    outcomes = outcomes,
    covariates = covariates,
    treatment_level = treatment_level,
    predictions = predictions,
    K = K,
    propensity_learner = propensity_learner,
    outcome_learner = outcome_learner,
    parallel = parallel,
    ...
  )

  ps <- cf_nuisance$ps
  q_hat <- cf_nuisance$q

  # Indicator of membership in the treatment arm of interest
  I_a <- as.numeric(treatment == treatment_level)

  # Truncate propensity scores away from 0/1 for numerical stability
  ps <- pmax(pmin(ps, 0.99), 0.01)

  # Pairwise concordance with mid-rank tie handling: 1 if f_i > f_j,
  # 0.5 if f_i == f_j. FIX: the original used a strict ">" here while the
  # DeLong SE below used 0.5 for ties, which was internally inconsistent.
  ind_f <- outer(predictions, predictions, ">") +
    0.5 * outer(predictions, predictions, "==")

  # IPW component: reweight observed case/control pairs.
  # The diagonal is automatically zero (Y_i cannot be both 1 and 0).
  pi_ratio <- I_a / ps
  mat_ipw0 <- outer(I_a * (outcomes == 1), I_a * (outcomes == 0), "*") *
    outer(pi_ratio, pi_ratio, "*")
  mat_ipw1 <- mat_ipw0 * ind_f

  # OM component: model-based pair mass q_i * (1 - q_j).
  # FIX: exclude self-pairs (i == j) here as well -- the original zeroed the
  # diagonal only for the DR correction term, leaving q_i * (1 - q_i)
  # self-pair mass in the OM component and making the three components
  # inconsistent about which pairs enter the estimand.
  mat_om0 <- outer(q_hat, 1 - q_hat, "*")
  diag(mat_om0) <- 0
  mat_om1 <- mat_om0 * ind_f

  # DR correction term: removes the overlap double-counted by IPW + OM
  mat_dr0 <- outer(I_a * pi_ratio * q_hat, I_a * pi_ratio * (1 - q_hat), "*")
  diag(mat_dr0) <- 0
  mat_dr1 <- mat_dr0 * ind_f

  # DR estimator: concordant pair mass over total pair mass
  numerator <- sum(mat_ipw1) + sum(mat_om1) - sum(mat_dr1)
  denominator <- sum(mat_ipw0) + sum(mat_om0) - sum(mat_dr0)
  estimate <- numerator / denominator

  # Approximate SE via DeLong variance components on the raw predictions.
  # NOTE(review): this ignores the IPW/OM weighting and the uncertainty from
  # nuisance estimation, so it is an approximation to the DR estimator's
  # sampling variance -- an influence-function-based SE would be preferable.
  cases <- which(outcomes == 1)
  controls <- which(outcomes == 0)
  n1 <- length(cases)
  n0 <- length(controls)

  # Vectorized placement values (replaces the original O(n1 * n0) loops):
  # cmp[i, j] = [f_case_i > f_control_j] + 0.5 * [f_case_i == f_control_j]
  # V10[i] = mean over controls; V01[j] = mean over cases (same comparison
  # read column-wise, since f_ctrl < f_case <=> f_case > f_ctrl).
  cmp <- outer(predictions[cases], predictions[controls], ">") +
    0.5 * outer(predictions[cases], predictions[controls], "==")
  V10 <- rowMeans(cmp)
  V01 <- colMeans(cmp)

  # DeLong variance estimate (0 variance components when a group has <= 1 obs)
  S10 <- if (n1 > 1) var(V10) else 0
  S01 <- if (n0 > 1) var(V01) else 0
  se <- sqrt(S10 / n1 + S01 / n0)

  list(
    estimate = estimate,
    se = se,
    ps = ps,
    q = q_hat,
    folds = cf_nuisance$folds
  )
}

man/cf_auc.Rd

Lines changed: 13 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)