Add SoftmaxWeightedSumFitter for synthetic control #822
Conversation
👋 Welcome to CausalPy, @thomaspinder! Thank you for opening your first pull request! We're excited to have you contribute to the project. 🎉 A maintainer will review your changes soon. Thanks for helping make CausalPy better! 🚀
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #822      +/-   ##
==========================================
+ Coverage   94.20%   94.26%   +0.06%
==========================================
  Files          78       78
  Lines       11880    12005     +125
  Branches      695      699       +4
==========================================
+ Hits        11192    11317     +125
  Misses        496      496
  Partials      192      192
```
drbenvincent
left a comment
Overall assessment
Great first contribution! This is a clean, well-tested, self-contained addition that provides an alternative weight parameterization for synthetic control. The softmax-over-Normal-logits approach is well-motivated by the SDiD literature (Arkhangelsky et al., 2021) and serves as a solid foundation for the Bayesian SDiD work in PR #823. The docstrings are excellent — the Notes section clearly contrasts the two parameterizations and their regularization behaviour.
A few minor suggestions below. Most are non-blocking, but comment 3 (priors_from_data ignoring the data) deserves a response — I'd like to understand the reasoning before approving.
Code comments
1. _softmax_simplex_weights: consider validating the prior distribution
The docstring says the prior "Must be a Normal prior", but there's no runtime check. If someone passes a Dirichlet or Cauchy prior here the model would build without complaint but produce unexpected behaviour. A lightweight guard would help, especially since this helper will also be used in PR #823.
```python
# suggestion (optional — depends on how public this helper is intended to be)
if prior.distribution != "Normal":
    raise ValueError(
        f"_softmax_simplex_weights expects a Normal prior, got {prior.distribution}"
    )
```

Not blocking — this is an internal helper prefixed with `_`, so the audience is narrow. But worth considering since #823 will call it from a second model class.
2. _softmax_simplex_weights: the 1D branch
The 1D/2D branching logic works correctly, but the condition n_rows == 1 and raw.ndim == 1 couples two independent checks. In practice this branch is only hit when priors_from_data happens to omit the treated_units dim, producing a 1D raw tensor. A brief inline comment explaining when each branch triggers would help future readers. The current comment says "1D case" but doesn't explain why it would be 1D.
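To make the two branches concrete, here is a numpy sketch of the pinned-first-logit softmax mapping in both the 1D and 2D cases. The function name, shapes, and branching are assumptions for illustration, not the PR's actual PyTensor code:

```python
import numpy as np

def softmax_simplex_weights_sketch(raw):
    """Map unconstrained free logits to simplex weights, with the first
    logit pinned to zero (reference category). Handles a 1D vector of
    free logits or a 2D (treated_units, coeffs_raw) array.

    Hypothetical numpy analogue of the helper under review.
    """
    raw = np.asarray(raw, dtype=float)
    if raw.ndim == 1:
        # 1D case: a single treated unit, so `raw` is one vector of
        # n_controls - 1 free logits.
        logits = np.concatenate([[0.0], raw])
        e = np.exp(logits - logits.max())
        return e / e.sum()
    # 2D case: one row of free logits per treated unit; prepend a
    # zero column, then softmax row-wise.
    zeros = np.zeros((raw.shape[0], 1))
    logits = np.concatenate([zeros, raw], axis=1)
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)
```

Seeing the two shapes side by side may also help decide whether the 1D branch is worth keeping or whether callers should always pass 2D input.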
3. priors_from_data ignores the data
Unlike WeightedSumFitter.priors_from_data, which adapts the Dirichlet shape to X.shape[1], this implementation always returns sigma=1.0 regardless of the data. This is fine as a default, but it means users who want tighter regularization (more DiD-like) or looser regularization (more SC-like) need to know to pass custom priors.
A docstring note explaining how to override sigma would be helpful for users:
```python
# Example of overriding the logit prior scale for tighter regularization:
model = SoftmaxWeightedSumFitter(
    priors={"beta_raw": Prior("Normal", mu=0, sigma=0.1, dims=["treated_units", "coeffs_raw"])}
)
```

4. build_model: coeffs_raw fallback

```python
coeffs_raw = (
    coords["coeffs"][1:]
    if coords and "coeffs" in coords
    else list(range(1, X.shape[1]))
)
```

The fallback `list(range(1, X.shape[1]))` produces integer coordinate labels, while the primary path produces string labels (sliced from the user's coords). This asymmetry is harmless in practice since coords are always passed through the experiment classes, but it's a bit surprising. Could this fallback raise instead? If coords is missing "coeffs", that's likely a usage error rather than an expected path.
5. Test: test_softmax_weighted_sum_fitter_priors_from_data could also verify dims adapt to data shape
The test verifies that priors_from_data returns the right distribution and fixed parameters, but since the current implementation doesn't actually use X or y, it doesn't test that the prior shape would be wrong for a different number of control units. This is fine for now but worth noting as a gap — if priors_from_data later adapts sigma to data scale, the test should be updated to verify that.
Questions for the author
1. Pinning the first logit vs the last: The choice to pin the first logit to zero means the first control unit in `coords["coeffs"]` becomes the reference. This is arbitrary and mathematically equivalent to pinning any other position, but is there a reason you chose the first? Just want to confirm this is intentional and not a consequence of the concatenation order. (It matches the blog post's parameterization, so I suspect it's deliberate.)
2. Connection to PR #823: The `_softmax_simplex_weights` helper is designed to be reused in the SDiD weight fitter. Have you verified that the 1D branch (which is never exercised in PR #822's tests) is actually needed for #823, or could the helper be simplified to always use the 2D path?
Summary
| Category | Verdict |
|---|---|
| Correctness | No issues found |
| Tests | Comprehensive — unit, integration, and prior tests all present |
| CI | All 17 checks pass |
| Documentation | Excellent docstrings with math, contrast to Dirichlet, and runnable examples |
| Code style | Consistent with existing WeightedSumFitter patterns |
| Public API | Consistent — accessible via cp.pymc_models.SoftmaxWeightedSumFitter, same pattern as WeightedSumFitter |
Recommendation: Approve with minor suggestions, pending a response on comment 3 (priors_from_data not adapting to the data).
…risation

Introduces a _softmax_simplex_weights helper and SoftmaxWeightedSumFitter, providing an alternative to WeightedSumFitter for synthetic control that places Normal priors on unconstrained logits and maps to the simplex via softmax (first logit pinned to zero). This parameterisation enables continuous regularisation control via the prior scale, bridging between DiD-uniform and SC-sparse weight regimes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
c7befc5 to c4283d9
On Code Comment:
On Questions:
drbenvincent
left a comment
Approved. Thanks for this!
Hoping to get to review the other, bigger, PR soon.
Introduces a `_softmax_simplex_weights` helper and `SoftmaxWeightedSumFitter`, providing an alternative to `WeightedSumFitter` for synthetic control that places Normal priors on unconstrained logits and maps to the simplex via softmax (first logit pinned to zero). This parameterisation enables continuous regularisation control via the prior scale, bridging between DiD-uniform and SC-sparse weight regimes.

This PR is the first of two that work towards implementing SDiD
Relevant Issue: #47
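The bridging claim in the description can be illustrated with a small numpy simulation, independent of CausalPy (`softmax_pinned` and `mean_max_weight` are made-up names for this sketch): with a tight prior scale on the free logits the implied weights stay near uniform (DiD-like), while a loose scale lets prior mass concentrate on a few controls (SC-like).

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax_pinned(raw):
    # Prepend the pinned zero logit along the last axis, then softmax.
    zeros = np.zeros(raw.shape[:-1] + (1,))
    logits = np.concatenate([zeros, raw], axis=-1)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mean_max_weight(sigma, n_controls=5, n_draws=2000):
    # Draw free logits from a Normal(0, sigma) prior and map to weights;
    # the average largest weight summarises how concentrated they are.
    raw = rng.normal(0.0, sigma, size=(n_draws, n_controls - 1))
    return softmax_pinned(raw).max(axis=-1).mean()

tight = mean_max_weight(0.05)  # near uniform: max weight close to 1/5
loose = mean_max_weight(5.0)   # concentrated: max weight much larger
```

With five controls, `tight` sits just above the uniform value 0.2 while `loose` is far larger, which is the continuous DiD-to-SC dial the prior scale provides.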