
Add SoftmaxWeightedSumFitter for synthetic control#822

Merged
drbenvincent merged 3 commits into main from feat/softmax-weighted-sum-fitter on Apr 13, 2026
Conversation

@thomaspinder
Contributor

Introduces a _softmax_simplex_weights helper and SoftmaxWeightedSumFitter, providing an alternative to WeightedSumFitter for synthetic control that places Normal priors on unconstrained logits and maps to the simplex via softmax (first logit pinned to zero). This parameterisation enables continuous regularisation control via the prior scale, bridging between DiD-uniform and SC-sparse weight regimes.

This PR is the first of two working towards implementing SDiD (synthetic difference-in-differences).

Relevant Issue: #47
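As a rough illustration of the parameterisation described above: a numpy sketch of the softmax-over-logits mapping with the first logit pinned to zero (`softmax_pinned_first` is a hypothetical name, not the library's `_softmax_simplex_weights` helper):

```python
import numpy as np

def softmax_pinned_first(free_logits):
    """Map N-1 unconstrained logits to an N-simplex, pinning the first
    logit to zero for identifiability (illustration only)."""
    logits = np.concatenate([[0.0], free_logits])
    exp = np.exp(logits - logits.max())  # subtract max for numerical stability
    return exp / exp.sum()

# Zero free logits give uniform (DiD-like) weights; large logits
# concentrate mass on one control unit (SC-like).
w = softmax_pinned_first(np.array([0.5, -1.0, 2.0]))
print(w, w.sum())
```

Placing a Normal prior on `free_logits` then makes the prior scale a continuous knob between those two regimes.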

@github-actions
Contributor

github-actions bot commented Apr 3, 2026

👋 Welcome to CausalPy, @thomaspinder!

Thank you for opening your first pull request! We're excited to have you contribute to the project. 🎉

Here are a few tips to help your PR get merged smoothly:

  • ✅ Make sure all CI checks pass (tests, linting, type checking)
  • 📝 Run prek run --all-files locally before pushing
  • 📖 Check our Contributing Guide for more details

A maintainer will review your changes soon. Thanks for helping make CausalPy better! 🚀


💼 LinkedIn Shoutout: Once your PR is merged, we'd love to give you a shoutout on LinkedIn to thank you for your contribution! If you're interested, just drop your LinkedIn profile URL in a comment below.

@codecov

codecov bot commented Apr 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 94.26%. Comparing base (711970c) to head (41bcd72).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #822      +/-   ##
==========================================
+ Coverage   94.20%   94.26%   +0.06%     
==========================================
  Files          78       78              
  Lines       11880    12005     +125     
  Branches      695      699       +4     
==========================================
+ Hits        11192    11317     +125     
  Misses        496      496              
  Partials      192      192              

☔ View full report in Codecov by Sentry.

@read-the-docs-community

read-the-docs-community bot commented Apr 4, 2026

Documentation build overview

📚 causalpy | 🛠️ Build #32190438 | 📁 Comparing 41bcd72 against latest (711970c)

  🔍 Preview build  

Show files changed (45 files in total): 📝 4 modified | ➕ 41 added | ➖ 0 deleted
File Status
404.html 📝 modified
genindex.html 📝 modified
_modules/causalpy/pymc_models.html 📝 modified
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.__init__.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.add_coord.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.add_coords.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.add_named_variable.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.build_model.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.calculate_cumulative_impact.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.calculate_impact.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.check_start_vals.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.compile_d2logp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.compile_dlogp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.compile_fn.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.compile_logp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.copy.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.create_value_var.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.d2logp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.debug.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.dlogp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.eval_rv_shapes.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.fit.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.get_context.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.initial_point.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.logp.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.logp_dlogp_function.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.make_obs_var.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.name_for.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.name_of.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.point_logps.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.predict.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.print_coefficients.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.priors_from_data.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.profile.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.register_data_var.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.register_rv.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.replace_rvs_by_values.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.score.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.set_data.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.set_dim.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.set_initval.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.shape_from_dims.html ➕ added
api/generated/causalpy.pymc_models.SoftmaxWeightedSumFitter.to_graphviz.html ➕ added
api/generated/causalpy.pymc_models.html 📝 modified

Collaborator

@drbenvincent left a comment


Overall assessment

Great first contribution! This is a clean, well-tested, self-contained addition that provides an alternative weight parameterization for synthetic control. The softmax-over-Normal-logits approach is well-motivated by the SDiD literature (Arkhangelsky et al., 2021) and serves as a solid foundation for the Bayesian SDiD work in PR #823. The docstrings are excellent — the Notes section clearly contrasts the two parameterizations and their regularization behaviour.

A few minor suggestions below. Most are non-blocking, but comment 3 (priors_from_data ignoring the data) deserves a response — I'd like to understand the reasoning before approving.


Code comments

1. _softmax_simplex_weights: consider validating the prior distribution

The docstring says the prior "Must be a Normal prior", but there's no runtime check. If someone passes a Dirichlet or Cauchy prior here the model would build without complaint but produce unexpected behaviour. A lightweight guard would help, especially since this helper will also be used in PR #823.

# suggestion (optional — depends on how public this helper is intended to be)
if prior.distribution != "Normal":
    raise ValueError(
        f"_softmax_simplex_weights expects a Normal prior, got {prior.distribution}"
    )

Not blocking — this is an internal helper prefixed with _, so the audience is narrow. But worth considering since #823 will call it from a second model class.

2. _softmax_simplex_weights: the 1D branch

The 1D/2D branching logic works correctly, but the condition n_rows == 1 and raw.ndim == 1 couples two independent checks. In practice this branch is only hit when priors_from_data happens to omit the treated_units dim, producing a 1D raw tensor. A brief inline comment explaining when each branch triggers would help future readers. The current comment says "1D case" but doesn't explain why it would be 1D.

3. priors_from_data ignores the data

Unlike WeightedSumFitter.priors_from_data, which adapts the Dirichlet shape to X.shape[1], this implementation always returns sigma=1.0 regardless of the data. This is fine as a default, but it means users who want tighter regularization (more DiD-like) or looser regularization (more SC-like) need to know to pass custom priors.

A docstring note explaining how to override sigma would be helpful for users:

# Example of overriding the logit prior scale for tighter regularization:
model = SoftmaxWeightedSumFitter(
    priors={"beta_raw": Prior("Normal", mu=0, sigma=0.1, dims=["treated_units", "coeffs_raw"])}
)

4. build_model: coeffs_raw fallback

coeffs_raw = (
    coords["coeffs"][1:]
    if coords and "coeffs" in coords
    else list(range(1, X.shape[1]))
)

The fallback list(range(1, X.shape[1])) produces integer coordinate labels, while the primary path produces string labels (sliced from the user's coords). This asymmetry is harmless in practice since coords are always passed through the experiment classes, but it's a bit surprising. Could this fallback raise instead? If coords is missing "coeffs", that's likely a usage error rather than an expected path.
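A minimal sketch of that suggestion, extracted into a standalone function for illustration (`coeffs_raw_labels` is a hypothetical name; in the real class this logic sits inline in `build_model`):

```python
def coeffs_raw_labels(coords):
    """Return labels for the free (non-pinned) logits, failing loudly
    when coords lacks "coeffs" instead of fabricating integer labels."""
    if not coords or "coeffs" not in coords:
        raise ValueError(
            "build_model requires coords['coeffs'] to label control units; "
            f"got coords={coords!r}"
        )
    # Drop the first label: its logit is pinned to zero.
    return list(coords["coeffs"][1:])
```

This keeps the coordinate labels homogeneous (always the user's strings) rather than mixing string and integer labels depending on the code path.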

5. Test: test_softmax_weighted_sum_fitter_priors_from_data could also verify dims adapt to data shape

The test verifies that priors_from_data returns the right distribution and fixed parameters, but since the current implementation doesn't actually use X or y, it doesn't test that the prior shape would be wrong for a different number of control units. This is fine for now but worth noting as a gap — if priors_from_data later adapts sigma to data scale, the test should be updated to verify that.


Questions for the author

  1. Pinning the first logit vs the last: The choice to pin the first logit to zero means the first control unit in coords["coeffs"] becomes the reference. This is arbitrary and mathematically equivalent to pinning any other position, but is there a reason you chose the first? Just want to confirm this is intentional and not a consequence of the concatenation order. (It matches the blog post's parameterization, so I suspect it's deliberate.)

  2. Connection to PR #823: The _softmax_simplex_weights helper is designed to be reused in the SDiD weight fitter. Have you verified that the 1D branch (which is never exercised in PR #822's tests) is actually needed for #823, or could the helper be simplified to always use the 2D path?


Summary

| Category | Verdict |
| --- | --- |
| Correctness | No issues found |
| Tests | Comprehensive — unit, integration, and prior tests all present |
| CI | All 17 checks pass |
| Documentation | Excellent docstrings with math, contrast to Dirichlet, and runnable examples |
| Code style | Consistent with existing WeightedSumFitter patterns |
| Public API | Consistent — accessible via cp.pymc_models.SoftmaxWeightedSumFitter, same pattern as WeightedSumFitter |

Recommendation: Approve with minor suggestions, pending a response on comment 3 (priors_from_data not adapting to the data).

thomaspinder and others added 2 commits April 9, 2026 19:28
…risation

Introduces a _softmax_simplex_weights helper and SoftmaxWeightedSumFitter,
providing an alternative to WeightedSumFitter for synthetic control that
places Normal priors on unconstrained logits and maps to the simplex via
softmax (first logit pinned to zero). This parameterisation enables
continuous regularisation control via the prior scale, bridging between
DiD-uniform and SC-sparse weight regimes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@thomaspinder force-pushed the feat/softmax-weighted-sum-fitter branch from c7befc5 to c4283d9 on April 9, 2026 17:28
@thomaspinder
Contributor (Author)

On the code comments:

  1. Good callout - updated.
  2. Docstring updated
  3. This is deliberate, and the mechanism differs from WeightedSumFitter. The Dirichlet dist. used in WeightedSumFitter requires an explicit length-N concentration vector a=np.ones(n_predictors), so priors_from_data must read X.shape[1] to set the shape. In this class, the Normal dist.'s logits prior broadcasts via dims=["treated_units", "coeffs_raw"], so the shape adapts without inspecting the data. The sigma=1.0 is the analogue of a=1, imposing moderate regularisation. Documenting how to override sigma is a good suggestion. I'll add a docstring example showing tighter (sigma=0.1, DiD-like) and looser (sigma=10, SC-like) configurations.
  4. Agreed — if coords is missing "coeffs" that's a usage error. Will replace the fallback with a ValueError.
  5. Acknowledged. The test will become more meaningful if priors_from_data later adapts to data scale.
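A quick numpy sketch of the regularisation behaviour described in point 3, sampling weights directly from the prior rather than through PyMC (illustrative only; function and variable names are not from the library):

```python
import numpy as np

rng = np.random.default_rng(0)

def prior_simplex_weights(sigma, n_units=10, n_draws=2000):
    """Draw simplex weights implied by Normal(0, sigma) logits
    with the first logit pinned to zero."""
    raw = rng.normal(0.0, sigma, size=(n_draws, n_units - 1))
    logits = np.concatenate([np.zeros((n_draws, 1)), raw], axis=1)
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

for sigma in (0.1, 1.0, 10.0):
    w = prior_simplex_weights(sigma)
    # Mean of the largest weight per draw: near 1/n_units when sigma is
    # small (DiD-like uniform), near 1 when sigma is large (SC-like sparse).
    print(sigma, w.max(axis=1).mean())
```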

On Questions:

  1. Pinning the first logit is deliberate as it matches the standard softmax identifiability convention. Pinning any position is mathematically equivalent - I just chose the first.
  2. The SDiD weight fitter calls _softmax_simplex_weights(n_rows=1, ...) for both omega (unit weights) and lam (time weights), both producing 1D outputs. Without the 1D branch, the concatenation would produce an incorrect (1, N) shape for what should be a 1D simplex.
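A small numpy illustration of the shape issue raised in that answer (assumed shapes, not the library code): with a single row of logits, the naive 2D concatenation yields a (1, N) matrix where SDiD's omega/lam weights should be flat (N,) simplex vectors.

```python
import numpy as np

raw_1d = np.array([0.3, -0.7])               # 1D free logits, N - 1 = 2
logits_1d = np.concatenate([[0.0], raw_1d])  # 1D branch: shape (3,)

raw_2d = raw_1d[None, :]                     # same logits as a (1, 2) row
logits_2d = np.concatenate(
    [np.zeros((1, 1)), raw_2d], axis=1
)                                            # 2D branch: shape (1, 3)

print(logits_1d.shape, logits_2d.shape)      # (3,) vs (1, 3)
```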

Collaborator

@drbenvincent left a comment


Approved. Thanks for this!

Hoping to get to review the other, bigger, PR soon.

@drbenvincent merged commit dc44c69 into main on Apr 13, 2026
15 checks passed