X-Learner: Use the same sample splits in all base models.#84
X-Learner: Use the same sample splits in all base models.#84Kevin Klein (kklein) wants to merge 11 commits intomainfrom
Conversation
|
FYI Matthias Loeffler (@MatthiasLoefflerQC) I created a first draft of how the same splits could be used for all base learners, including treatment models. As of now the estimates are still clearly awry, e.g. an RMSE of ~13 compared to ~0.05. This happens for both in-sample and out-of-sample estimation. I currently have no real ideas on what's going wrong; will try to make some progress still |
| if synchronize_cross_fitting: | ||
| cv_split_indices = self._split( | ||
| index_matrix(X, self._treatment_variants_indices[treatment_variant]) | ||
| treatment_indices = np.where( |
There was a problem hiding this comment.
This is an opaque way of turning an array [True, True, False, False, True] into an array [0, 1, 4]. Not sure if there's a neater way of doing that.
There was a problem hiding this comment.
[index for index, value in enumerate(vector) if value] would work too, I guess, and is more verbose, but I like the np.where :)
0785dcd to
410e9e7
Compare
The base models all seem to be doing fine wrt their individual targets at hand. Yet, when I compare pairs of treatment effect model estimates at prediction time, it become blatantly apparent that something is going wrong: Update: These discrepancies have been substantially reduced by bbfff15. The RMSEs on true cates are still massive when compared to status quo. |
Co-authored-by: Matthias Loeffler <106818324+MatthiasLoefflerQC@users.noreply.github.com>
| model_ord=treatment_variant - 1, | ||
| is_oos=False, | ||
| ) | ||
| )[control_indices] |
There was a problem hiding this comment.
do we need is_oos=False below (and likewise for tau_hat_treatment)? Might be worth a try.
TODOs:
cvare actually not used for training when passingcvtocross_validate.CrossFitEstimator.synchronize_cross_fittingshould be allowed to beFalsefor the X-Learner.Observations
yields the following output:
Checklist
CHANGELOG.rstentry