## Issue at hand
ArseniyZvyagintsevQC brought the following to our attention:
Let us assume a binary treatment variant scenario in which we want to work with in-sample predictions, i.e. `is_oos=False`.
The current implementation would fit five models, three of which are considered nuisance models and two of which are considered treatment models:
| model | name |
|---|---|
| nuisance | "treatment_variant" |
| nuisance | "treatment_variant" |
| nuisance/propensity | "propensity_model" |
| treatment | "control_effect_model" |
| treatment | "treatment_effect_model" |
More background on this here.
Note that each of these models is cross-fitted. More precisely, each is cross-fitted wrt the data it has seen at training time.
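To make the cross-fitting mechanics concrete, here is a minimal, self-contained sketch of an estimator that supports both prediction modes. The class name `CrossFitEstimator` and the `is_oos` flag mirror the terminology of this discussion but are illustrative, not the library's actual API; scikit-learn is assumed as the modeling backend.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold


class CrossFitEstimator:
    """Illustrative cross-fitted estimator (not the library's real class)."""

    def __init__(self, base, n_folds=5):
        self.base = base
        self.n_folds = n_folds

    def fit(self, X, y):
        self.models_ = []
        self.test_indices_ = []
        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=0)
        for train_idx, test_idx in kf.split(X):
            # Each fold model only ever sees its training split.
            self.models_.append(clone(self.base).fit(X[train_idx], y[train_idx]))
            self.test_indices_.append(test_idx)
        return self

    def predict(self, X, is_oos):
        if is_oos:
            # Out-of-sample: no fold model has seen X, so average all of them.
            return np.mean([m.predict(X) for m in self.models_], axis=0)
        # In-sample: X is assumed to be the training data, in training order.
        # Each point is predicted by the one fold model that did not see it,
        # which is what avoids leakage.
        preds = np.empty(len(X))
        for model, test_idx in zip(self.models_, self.test_indices_):
            preds[test_idx] = model.predict(X[test_idx])
        return preds
```

Note the asymmetry this sketch makes explicit: `is_oos=False` only makes sense for the exact data seen at fit time, which is why the distinction between the two modes matters at inference time.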
Let's suppose now that we are at inference time and encounter an in-sample data point $X_i$ stemming from the treated group.
In order to come up with a CATE estimate, the `predict` method will run

- $\hat{\tau}_0(X_i)$ with `is_oos=True` since this data point has not been seen during training time of the model $\hat{\tau}_0$
- $\hat{\tau}_1(X_i)$ with `is_oos=False` since this data point has indeed been seen during training time of the model $\hat{\tau}_1$
The latter call makes sure we avoid leakage in $\hat{\tau}_1(X_i)$ even though $X_i$ has been seen during the training of $\hat{\tau}_1$. The former call, however, can still leak: even though $X_i$ has not been used to train $\hat{\tau}_0$ directly, it has been used to train the nuisance models whose predictions define $\hat{\tau}_0$'s training targets.
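For context, in an X-learner-style setup such as this one, the two per-group effect estimates are typically combined into a single CATE estimate via propensity weighting. The function below is a hypothetical sketch of that combination step, not the library's `predict` implementation; `tau_0`, `tau_1`, and `propensity` stand in for the fitted models above.

```python
def combine_cate(x, tau_0, tau_1, propensity):
    """Combine per-group effect estimates into one CATE estimate (sketch).

    tau_0, tau_1 and propensity are assumed to be callables returning
    point predictions; they stand in for the fitted models.
    """
    g = propensity(x)  # estimated P(W = 1 | X = x)
    # Propensity weighting: lean on tau_0 where treatment is likely,
    # on tau_1 where control is likely.
    return g * tau_0(x) + (1.0 - g) * tau_1(x)
```

This is why a single in-sample CATE query touches both effect models (with different `is_oos` values) as well as the propensity model.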
## Next steps
We can devise an extreme, naïve approach to counteract this issue: training every type of model once per data point. Clearly, this ensures the absence of data leakage. The challenge revolves around coming up with a design that
- allows for arbitrary numbers (>1, <=n) of cross-fitting folds, i.e. not fixing it to be equal to the number of training data points
- integrates well into the structure of the library
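The fold-count requirement above can be sketched as follows: the naïve per-datapoint scheme is just the special case where the number of folds equals the number of training points (leave-one-out), and any count in between should be admissible. The helper name `make_splitter` is hypothetical; scikit-learn splitters are assumed.

```python
from sklearn.model_selection import KFold, LeaveOneOut


def make_splitter(n_samples, n_folds):
    """Return a cross-fitting splitter for any admissible fold count (sketch)."""
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must lie in [2, n_samples]")
    if n_folds == n_samples:
        # The extreme, naive variant: one held-out data point per fold.
        return LeaveOneOut()
    return KFold(n_splits=n_folds)
```

A design along these lines would let users trade off the leakage guarantees of the extreme variant against the cost of fitting one model per data point.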