@MoAly98
Owner

@MoAly98 MoAly98 commented Jan 25, 2026

This should be general. I implemented the main tests, which are provided by default; the maths in the asymptotic approximations still needs to be double-checked. Users should be able to construct their own test statistic and test-statistic distribution (whatever approximation is relevant to them) and pass them to our generic calculator, which performs the test. We should try this with a chi2 as an example.
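As a rough illustration of the chi2 case mentioned above, the sketch below pairs a chi-square test statistic with its one-degree-of-freedom asymptotic distribution. The names `chi2_statistic` and `chi2_pvalue` are hypothetical and are not part of this PR's API; the example only assumes a single Gaussian measurement.

```python
import math


def chi2_statistic(observed: float, predicted: float, sigma: float) -> float:
    """Chi-square statistic for one Gaussian measurement (hypothetical helper)."""
    return ((observed - predicted) / sigma) ** 2


def chi2_pvalue(q: float) -> float:
    """Survival function of a chi-square distribution with 1 degree of freedom."""
    # For 1 dof, P(chi2 >= q) = erfc(sqrt(q / 2)).
    return math.erfc(math.sqrt(max(q, 0.0) / 2.0))


q_obs = chi2_statistic(observed=12.0, predicted=10.0, sigma=1.0)
p_value = chi2_pvalue(q_obs)
print(f"q = {q_obs:.2f}, p = {p_value:.4f}")
```

A user-defined statistic and distribution would presumably follow the same split: one callable that produces q, one that maps q to a p-value.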

To be merged after #34

Summary by CodeRabbit

  • New Features

    • Introduced comprehensive hypothesis testing framework for CLs calculations with support for asymptotic and toy-based approaches, including test statistics (QTilde, QMu, Q0, TMu), p-value distributions, and upper limit solvers with Brazil bands.
    • Added parameter uncertainty estimation functions: Hessian, covariance, correlation matrices, and per-parameter uncertainties.
  • Documentation

    • Added example notebook demonstrating EFT limit extraction workflows with toy generation and asymptotic hypothesis testing.


@coderabbitai
Contributor

coderabbitai bot commented Jan 25, 2026

📝 Walkthrough

Introduces a comprehensive hypothesis testing framework with test statistics (QTilde, QMu, Q0, TMu), asymptotic and empirical distributions, an orchestrator calculator, toy-based Monte Carlo support, and upper limit solvers. Also adds parameter uncertainty estimation via Hessian-based covariance and expands the public API surface. Includes an example notebook demonstrating EFT limit extraction.

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| Hypothesis Testing Framework - Test Statistics | src/everwillow/inference/hypotest/test_statistics.py | Introduces abstract TestStatistic base class with four concrete implementations: QTilde (profile likelihood with boundary), QMu (unconstrained likelihood ratio), Q0 (discovery statistic), and TMu (signed statistic). Each computes a TestStatResult with q value and extras dict containing fit results and asimov information. |
| Hypothesis Testing Framework - Distributions | src/everwillow/inference/hypotest/distributions.py | Adds abstract Distribution base class with four asymptotic implementations (QTildeAsymptotic, QMuAsymptotic, Q0Asymptotic, TMuAsymptotic) using Cowan formulas, and EmpiricalDistribution for toy-based p-value computation. Each provides pvalues() and expected_pvalues() methods with boundary handling. |
| Hypothesis Testing Framework - Core Orchestration | src/everwillow/inference/hypotest/calculators.py | Implements HypoTestCalculator, a lightweight orchestrator that computes test statistics, delegates to distributions for p-values, derives CLs, and aggregates expected bands into HypoTestResult. |
| Hypothesis Testing Framework - Toy Generator | src/everwillow/inference/hypotest/toys.py | Introduces ToyGenerator class for Monte Carlo-based testing. Profiles nuisance parameters, generates toys under alternative/null hypotheses via provided sampling function, and returns EmpiricalDistribution with empirical q arrays. Parallelizes via jax.vmap. |
| Hypothesis Testing Framework - Upper Limits | src/everwillow/inference/hypotest/upper_limit.py | Provides three JAX-compatible upper limit solvers: upper_limit (deterministic bisection), upper_limit_toys (stochastic bisection with fresh keys per iteration), and expected_upper_limit (computes observed + Brazil-band limits). Returns ExpectedLimitResult with −2σ to +2σ bands. |
| Hypothesis Testing Framework - Utilities & Results | src/everwillow/inference/hypotest/_utils.py, src/everwillow/inference/hypotest/_results.py | Adds utility functions (cl_s for CLs computation, constrained_fit for fixed-parameter fits) and five result data models (TestStatResult, ExpectedBands, HypoTestResult, HypoTestToysResult, ExpectedLimitResult) as Equinox modules. |
| Hypothesis Testing Framework - Public API | src/everwillow/inference/hypotest/__init__.py | Creates consolidated public API surface, re-exporting all test statistics, distributions, calculators, result types, and utility functions under a unified namespace with comprehensive module docstring and usage example. |
| Uncertainty Quantification | src/everwillow/inference/uncertainty.py | Introduces Hessian-based parameter uncertainty estimation with four functions: hessian_matrix (JAX autograd Hessian), covariance_matrix (Hessian inverse), correlation_matrix (normalized covariance), and uncertainties (per-parameter std devs). Handles fixed parameters via partitioning. |
| Uncertainty Quantification - Tests | tests/inference/test_uncertainty.py | Comprehensive test suite validating Hessian, covariance, correlation, and uncertainty computations against analytical expectations from a simple quadratic NLL. Includes shape validation, positive definiteness, symmetry, and integration workflow tests. |
| Public API Exports | src/everwillow/inference/__init__.py | Re-exports uncertainty functions (correlation_matrix, covariance_matrix, hessian_matrix, uncertainties) from everwillow.inference.uncertainty module, expanding the inference package's public surface. |
| State Library Enhancement | src/everwillow/statelib/state.py | Broadens the K type alias to allow int components in keys (str \| tuple[str \| int, ...]). Adds _flatten_iterables helper for recursive key flattening during canonicalization, enabling multi-element key expansion. |
| Configuration & Examples | pyproject.toml, examples/eft_limits_toys.ipynb | Updates Ruff ignore list to suppress RUF002/RUF003 (ambiguous Unicode) for physics contexts. Adds example notebook demonstrating EFT limit extraction workflow: model definition, toy generation, asymptotic vs. toy CLs comparison, Brazil-band limits, and JIT timing. |
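To make the uncertainty row concrete, here is a minimal sketch of the Hessian → covariance → correlation chain on a two-parameter quadratic NLL. It is not the PR's implementation: it swaps in central finite differences for JAX autodiff, uses plain lists instead of `sl.State`, and every function name below is hypothetical.

```python
import math


def nll(x: float, y: float) -> float:
    # Hypothetical quadratic NLL whose Hessian is [[4, 1], [1, 2]].
    return 0.5 * (4.0 * x * x + 2.0 * x * y + 2.0 * y * y)


def hessian_2d(f, x: float, y: float, h: float = 1e-3):
    """Central finite-difference Hessian of f at (x, y)."""
    fxx = (f(x + h, y) - 2.0 * f(x, y) + f(x - h, y)) / h**2
    fyy = (f(x, y + h) - 2.0 * f(x, y) + f(x, y - h)) / h**2
    fxy = (
        f(x + h, y + h) - f(x + h, y - h) - f(x - h, y + h) + f(x - h, y - h)
    ) / (4.0 * h**2)
    return [[fxx, fxy], [fxy, fyy]]


def invert_2x2(m):
    """Closed-form inverse of a 2x2 matrix."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


hess = hessian_2d(nll, 0.0, 0.0)          # Hessian at the minimum
cov = invert_2x2(hess)                    # covariance = inverse Hessian
sigma_x = math.sqrt(cov[0][0])            # per-parameter uncertainties
sigma_y = math.sqrt(cov[1][1])
corr_xy = cov[0][1] / (sigma_x * sigma_y) # normalized covariance
```

For this NLL the covariance is (1/7)·[[2, −1], [−1, 4]], which gives an analytic reference of the kind the test suite row describes.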

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant HypoTestCalculator
    participant TestStatistic as QTilde
    participant Distribution
    participant FitUtils as fit/constrained_fit
    
    User->>HypoTestCalculator: __call__(nll_fn, params, poi_key, poi_test, distribution)
    
    HypoTestCalculator->>TestStatistic: compute test statistic
    TestStatistic->>FitUtils: fit unconstrained
    FitUtils-->>TestStatistic: fit_free, mu_hat
    TestStatistic->>FitUtils: fit at poi_test (constrained)
    FitUtils-->>TestStatistic: fit_constrained
    TestStatistic->>TestStatistic: compute q = -2ln(L(poi_test)/L(mu_hat))
    TestStatistic-->>HypoTestCalculator: TestStatResult(q, extras)
    
    HypoTestCalculator->>Distribution: pvalues(test_stat_result)
    Distribution->>Distribution: compute pnull, palt from q
    Distribution-->>HypoTestCalculator: (pnull, palt)
    
    HypoTestCalculator->>HypoTestCalculator: cl_s = palt / pnull
    HypoTestCalculator->>Distribution: expected_pvalues(test_stat_result)
    Distribution-->>HypoTestCalculator: ExpectedBands(-2σ...+2σ)
    
    HypoTestCalculator-->>User: HypoTestResult(q_obs, pnull, palt, cl_s, expected_bands)
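The p-value and CLs steps in the diagram can be sketched with the textbook asymptotic formulas for the q_mu statistic. This is an illustration under stated assumptions, not the PR's code: `asymptotic_cls` is a hypothetical helper, `q_asimov` stands for the statistic evaluated on the background-only Asimov dataset, and no mapping onto the PR's pnull/palt naming is implied.

```python
from statistics import NormalDist

_PHI = NormalDist().cdf  # standard normal CDF


def asymptotic_cls(q_obs: float, q_asimov: float) -> tuple[float, float, float]:
    """Return (CLs+b, CLb, CLs) from the asymptotic q_mu formulas."""
    sqrt_q = max(q_obs, 0.0) ** 0.5
    sqrt_qa = max(q_asimov, 0.0) ** 0.5
    # P(q >= q_obs) under the tested signal+background hypothesis.
    cl_sb = 1.0 - _PHI(sqrt_q)
    # P(q >= q_obs) under the background-only hypothesis.
    cl_b = 1.0 - _PHI(sqrt_q - sqrt_qa)
    return cl_sb, cl_b, cl_sb / cl_b


cl_sb, cl_b, cls = asymptotic_cls(q_obs=3.84, q_asimov=3.84)
print(f"CLs+b = {cl_sb:.4f}, CLb = {cl_b:.4f}, CLs = {cls:.4f}")
```

When the observed statistic equals its Asimov value, CLb is exactly 0.5, which is the median-expected case used for the central Brazil band.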
sequenceDiagram
    participant User
    participant ToyGenerator
    participant SampleFn as sample_fn
    participant NLLFactory as nll_factory
    participant TestStat as TestStatistic
    participant JAX as jax.vmap
    
    User->>ToyGenerator: generate(nll_fn, params, poi_key, poi_test, sample_fn, nll_factory, key)
    
    ToyGenerator->>ToyGenerator: profile nuisance params at poi_test (alt & null)
    
    ToyGenerator->>JAX: _run_toys(parallel execution)
    
    loop for each toy iteration
        JAX->>SampleFn: sample_fn(sample_params, key)
        SampleFn-->>JAX: toy_data
        JAX->>NLLFactory: nll_factory(toy_data)
        NLLFactory-->>JAX: toy_nll
        JAX->>TestStat: compute test statistic on toy_nll
        TestStat-->>JAX: q_toy
    end
    
    JAX-->>ToyGenerator: arrays [q_alt, q_null]
    
    ToyGenerator-->>User: EmpiricalDistribution(q_alt=..., q_null=...)
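The toy loop in the diagram can be approximated in plain Python as below. This is only a stand-in: the PR vectorizes the loop with jax.vmap, whereas this sketch uses an explicit loop, a one-parameter Gaussian model with an analytic best fit, and hypothetical signatures for `sample_fn`, `nll_factory`, and `q_statistic`.

```python
import random


def sample_fn(mu: float, rng: random.Random) -> float:
    """Draw one toy measurement from a unit-width Gaussian centred on mu."""
    return rng.gauss(mu, 1.0)


def nll_factory(data: float):
    """Build the NLL for one toy dataset."""
    def nll(mu: float) -> float:
        return 0.5 * (data - mu) ** 2
    return nll


def q_statistic(nll, mu_test: float, mu_hat: float) -> float:
    """Likelihood-ratio statistic: 2 * (NLL(mu_test) - NLL(mu_hat))."""
    return 2.0 * (nll(mu_test) - nll(mu_hat))


def run_toys(mu_gen: float, mu_test: float, n_toys: int, seed: int) -> list[float]:
    rng = random.Random(seed)
    q_toys = []
    for _ in range(n_toys):
        toy = sample_fn(mu_gen, rng)   # generate toy data
        toy_nll = nll_factory(toy)     # NLL for this toy
        # For this model the best-fit mu is the measurement itself.
        q_toys.append(q_statistic(toy_nll, mu_test, mu_hat=toy))
    return q_toys


# Toys generated at the tested POI value and at POI = 0.
q_alt = run_toys(mu_gen=1.0, mu_test=1.0, n_toys=2000, seed=0)
q_null = run_toys(mu_gen=0.0, mu_test=1.0, n_toys=2000, seed=1)
```

The two lists play the role of the q_alt and q_null arrays handed to EmpiricalDistribution.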

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • #27: Changes FitResult structure and fit() signature/behavior that the new constrained_fit utility and test statistics directly depend on for profile likelihood computations.

Suggested reviewers

  • pfackeldey

Poem

🐰 With whiskers twitching, binning clues so bright,
We test hypotheses through day and night,
Q-tildes flutter, toys dance in the air,
Asymptotic bands beyond compare!
Brazil stripes painted, limits set with care—
This is the hypothesis testing lair! 🎭✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title "First implementation of hypothesis testing" accurately describes the primary change: introducing a new hypothesis-testing API and submodule with test statistics, distributions, and calculators. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 82.56% which is sufficient. The required threshold is 80.00%. |


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@src/everwillow/inference/hypotest/test_statistics.py`:
- Around line 132-138: The local variables fitted_state and fixed need explicit
type annotations to satisfy mypy's check_untyped_defs; annotate them as sl.State
(e.g., declare fitted_state: sl.State = sl.State.from_pytree(...) and fixed:
sl.State = sl.State.from_pytree(...)) so both variables are explicitly typed
when calling sl.State.from_pytree in the block that creates mu_hat and the
constrained fit (references: fitted_state, fixed, sl.State.from_pytree, poi_key,
poi_test).

In `@src/everwillow/inference/hypotest/toys.py`:
- Around line 153-165: The vmap over single_toy assumes sample_fn and
nll_factory are JAX-traceable and side-effect-free; if they are not, tracing can
fail, and Python side effects run once at trace time rather than once per toy.
Fix by either documenting this requirement in
the _run_toys docstring (explicitly state sample_fn and nll_factory must be
JAX-compatible/functional and avoid Python-side effects) or by ensuring
consistent compilation by wrapping the _run_toys body (including single_toy and
the jax.vmap call that produces q_toys) with jax.jit so the whole toy loop is
traced/compiled once; reference single_toy, sample_fn, nll_factory, _run_toys
and the jax.vmap(keys) call to locate where to apply the change.

In `@src/everwillow/inference/hypotest/upper_limit.py`:
- Around line 125-134: The docstring for the function upper_limit (parameter
max_iterations) is inconsistent with the function signature: the signature
default is 100 but the docstring states 15; update the docstring to state
"Maximum bisection iterations (default 100)" (or change the function signature
default to 15 if intended) so the textual docs match the actual parameter
default; specifically edit the parameter description for max_iterations in the
upper_limit function's docstring to reference the correct default value.
- Around line 220-248: Replace the **solver_kwargs unpacking with explicit
keyword args to preserve precise types: call upper_limit for observed using
rtol=rtol, atol=atol, max_steps=max_steps (instead of **solver_kwargs), and
likewise in limit_at_band return upper_limit(..., rtol=rtol, atol=atol,
max_steps=max_steps); keep the same arguments to upper_limit, using calc_fn,
cl_s, bounds, and level as before and remove reliance on the solver_kwargs dict
for these calls.

In `@src/everwillow/statelib/state.py`:
- Around line 226-235: The docstring example incorrectly calls
State.from_pytree(..., canonicalize=False) even though State.from_pytree does
not accept a canonicalize parameter; update the documentation to remove the
canonicalize argument and instead show the intended round-trip using a
pre-canonicalized dict (e.g., show creating flat = state.to_dict() and then
State.from_pytree(flat).to_dict() == flat) or, if the behavior was intended, add
a canonicalize parameter to State.from_pytree and implement handling; reference
the State.from_pytree and to_dict symbols when making the change.
🧹 Nitpick comments (8)
src/everwillow/inference/hypotest/distributions.py (4)

24-31: Consider sorting __all__ for consistency.

The linter flags that __all__ is not sorted. While this is a minor style concern, sorting can improve maintainability and make diffs cleaner when adding new exports.

♻️ Suggested fix
 __all__ = [
     "Distribution",
+    "EmpiricalDistribution",
+    "Q0Asymptotic",
+    "QMuAsymptotic",
     "QTildeAsymptotic",
-    "QMuAsymptotic",
-    "Q0Asymptotic",
     "TMuAsymptotic",
-    "EmpiricalDistribution",
 ]

172-175: Unused q_asimov parameter in _palt.

The q_asimov parameter is declared but never used in this method. This appears intentional since QMuAsymptotic doesn't apply boundary handling, but the parameter signature is inconsistent with the actual computation.

♻️ Suggested fix: Remove unused parameter or add comment

Option 1 - Remove unused parameter:

-    def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:
+    def _palt(self, q: Array, nsigma: float = 0.0) -> Array:
         """p-value under alternative hypothesis (signal+background)."""
         sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
         return 1.0 - jax.scipy.stats.norm.cdf(sqrt_q - nsigma)

Option 2 - Add a comment explaining why it's kept for API consistency:

-    def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:
+    def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:  # noqa: ARG002
         """p-value under alternative hypothesis (signal+background)."""
+        # q_asimov unused here (no boundary handling), kept for API consistency
         sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
         return 1.0 - jax.scipy.stats.norm.cdf(sqrt_q - nsigma)

319-324: Potential numerical instability with small toy sample sizes.

When q_null or q_alt arrays are small, jnp.mean(self.q_null >= q) could return extreme values (0.0 or 1.0), which may cause issues in downstream CLs calculations. Consider adding a note in the docstring about minimum recommended toy counts.
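A small sketch of the resolution limit described above: with N toys the raw tail fraction can only take values k/N, so it collapses to exactly 0.0 once the observed statistic exceeds every toy. The second estimator, (k + 1) / (N + 1), is a common Monte Carlo alternative that never returns zero; it is shown for comparison only and is not what the PR implements. Both function names are hypothetical.

```python
def empirical_pvalue(q_toys: list[float], q_obs: float) -> float:
    """Raw tail fraction: share of toys with q >= q_obs."""
    return sum(q >= q_obs for q in q_toys) / len(q_toys)


def smoothed_pvalue(q_toys: list[float], q_obs: float) -> float:
    """(k + 1) / (N + 1) estimator, bounded below by 1 / (N + 1)."""
    k = sum(q >= q_obs for q in q_toys)
    return (k + 1) / (len(q_toys) + 1)


q_toys = [0.1, 0.4, 0.9, 1.6, 2.5]  # deliberately tiny toy sample
p_raw = empirical_pvalue(q_toys, q_obs=10.0)
p_smooth = smoothed_pvalue(q_toys, q_obs=10.0)
print(p_raw, p_smooth)
```

A zero in the denominator of a CLs ratio is the downstream failure mode, which is why a documented minimum toy count (or a floor like the one above) is worth considering.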


326-351: Unused result parameter in expected_pvalues.

The result parameter is required by the abstract interface but is not used. This is acceptable for API consistency, but consider adding a comment or # noqa: ARG002 to suppress the linter warning.

♻️ Suggested fix
-    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands:
+    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands:  # noqa: ARG002
         """Compute expected p-values at standard sigma bands from toy distributions.
 
         For each test statistic value in q_null, computes what the p-values would be,
         then takes percentiles at the standard normal quantiles.
         """
+        del result  # Unused; required by Distribution interface
src/everwillow/inference/hypotest/toys.py (1)

93-96: Hardcoded null hypothesis POI value of 0.0 may not be appropriate for all use cases.

The null hypothesis assumes POI = 0 (background-only), which is standard for signal strength parameters but may not be suitable for all hypothesis testing scenarios (e.g., testing μ = 1 vs μ = μ_alt).

Consider making the null hypothesis POI value configurable:

♻️ Suggested fix
     def generate(
         self,
         nll_fn: tp.Callable[[PyTree], float],
         params: sl.State,
         poi_key: sl.K,
         poi_test: float,
         *,
         sample_fn: tp.Callable[[sl.State, PRNGKeyArray], tp.Any],
         nll_factory: tp.Callable[[sl.State, tp.Any], tp.Callable[[PyTree], float]],
         key: PRNGKeyArray,
+        poi_null: float = 0.0,
         **fit_kwargs: tp.Any,
     ) -> EmpiricalDistribution:
         ...
         # Null hypothesis: POI = poi_null (background-only by default)
-        fixed_null = sl.State.from_pytree({poi_key: 0.0})
+        fixed_null = sl.State.from_pytree({poi_key: poi_null})
tests/inference/test_uncertainty.py (1)

48-48: Global JAX configuration should be isolated to test scope.

Setting jax.config.update("jax_enable_x64", True) at module level affects all subsequent tests in the test session. Consider using a pytest fixture or conftest.py to manage this configuration more safely.

♻️ Suggested approach using a fixture
# In conftest.py or at the top of this file
import jax
import pytest


@pytest.fixture(autouse=True)
def enable_x64():
    """Enable 64-bit precision for these tests."""
    original = jax.config.jax_enable_x64
    jax.config.update("jax_enable_x64", True)
    yield
    # Restore the previous setting after each test.
    jax.config.update("jax_enable_x64", original)
src/everwillow/inference/uncertainty.py (2)

101-103: Consider handling singular/ill-conditioned Hessian matrices.

jnp.linalg.inv will silently produce NaN or Inf for singular matrices, which can occur at parameter boundaries or with highly correlated parameters. Consider adding a condition number check or using pseudo-inverse with a warning.

♻️ Suggested improvement
     hess = hessian_matrix(nll_fn, params, fixed=fixed)
-    # Invert Hessian to get Fisher information matrix (covariance)
-    return jnp.linalg.inv(hess)
+    # Invert the Hessian (observed Fisher information) to get the covariance matrix
+    # Use pseudo-inverse for numerical stability with near-singular matrices
+    return jnp.linalg.pinv(hess)

Alternatively, check the condition number and warn if the matrix is ill-conditioned before inversion.
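The condition-number alternative could look roughly like the sketch below. It is restricted to symmetric 2x2 matrices so that the eigenvalues have a closed form and no array library is needed; in the PR's JAX code the analogous check would presumably go through a library routine instead. The helper name and the 1e8 threshold are assumptions.

```python
import math


def condition_number_sym2(m) -> float:
    """2-norm condition number of a symmetric 2x2 matrix via its eigenvalues."""
    (a, b), (_, d) = m
    mean = 0.5 * (a + d)
    radius = math.hypot(0.5 * (a - d), b)
    # Eigenvalues are mean -/+ radius; the condition number is max|λ| / min|λ|.
    abs_eigs = sorted(abs(mean + sign * radius) for sign in (-1.0, 1.0))
    if abs_eigs[0] == 0.0:
        return math.inf
    return abs_eigs[1] / abs_eigs[0]


COND_THRESHOLD = 1e8  # assumed cutoff, tune for the problem at hand

well_conditioned = [[4.0, 1.0], [1.0, 2.0]]
nearly_singular = [[1.0, 1.0], [1.0, 1.0 + 1e-12]]  # almost fully correlated

cond_good = condition_number_sym2(well_conditioned)
cond_bad = condition_number_sym2(nearly_singular)
print(cond_good < COND_THRESHOLD, cond_bad < COND_THRESHOLD)
```

A check like this lets the caller warn before inverting rather than silently returning a pseudo-inverse.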


133-138: Return type annotation should reflect None values for fixed parameters.

The function returns None for fixed parameters (line 171), but the return type annotation is sl.State[V]. This should be sl.State[V | None] to accurately represent the return type.

📝 Suggested fix
 def uncertainties(
     nll_fn: tp.Callable[[PyTree[V]], float],
     params: sl.State[V],
     *,
     fixed: sl.State[V | EllipsisType] | None = None,
-) -> sl.State[V]:
+) -> sl.State[V | None]:

Comment on lines +132 to +138
fit_free = ew.fit(nll_fn, params, **fit_kwargs)
fitted_state = sl.State.from_pytree(fit_free.params)
mu_hat = fitted_state[poi_key]

# Constrained fit (POI fixed at test value)
fixed = sl.State.from_pytree({poi_key: poi_test})
fit_constrained = constrained_fit(nll_fn, params, fixed, **fit_kwargs)
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, check if the file exists and examine lines 132-138
if [ -f "src/everwillow/inference/hypotest/test_statistics.py" ]; then
  echo "=== File exists. Checking lines 130-140 ==="
  sed -n '130,140p' src/everwillow/inference/hypotest/test_statistics.py | cat -n
else
  echo "File not found. Listing directory structure:"
  find . -name "test_statistics.py" -type f 2>/dev/null | head -20
fi

Repository: MoAly98/everwillow

Length of output: 655


🏁 Script executed:

#!/bin/bash
# Search for sl.State imports in the test file
rg -n "^import|^from" src/everwillow/inference/hypotest/test_statistics.py | head -20

Repository: MoAly98/everwillow

Length of output: 427


🏁 Script executed:

#!/bin/bash
# Check for any mypy or CI configuration files that might show errors
find . -name "*.yml" -o -name "*.yaml" -o -name "pyproject.toml" -o -name "setup.cfg" | grep -E "(github|gitlab|mypy|ci)" | head -10

Repository: MoAly98/everwillow

Length of output: 113


🏁 Script executed:

#!/bin/bash
# Check the CI configuration to see if mypy is being run
cat ./.github/workflows/ci.yml | head -100

Repository: MoAly98/everwillow

Length of output: 2052


🏁 Script executed:

#!/bin/bash
# Find sl.State definition to understand what type it is
find . -path "./.*" -prune -o -name "*.py" -type f -exec grep -l "class State" {} \; | grep -v __pycache__ | head -10

Repository: MoAly98/everwillow

Length of output: 96


🏁 Script executed:

#!/bin/bash
# Check if there are any existing type annotations on similar lines in the same file
rg -n ":\s*sl\.State\s*=" src/everwillow/inference/hypotest/test_statistics.py

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Look at the full context around lines 130-145 to understand the function signature and context
sed -n '120,150p' src/everwillow/inference/hypotest/test_statistics.py | cat -n

Repository: MoAly98/everwillow

Length of output: 1376


🏁 Script executed:

#!/bin/bash
# Check the pre-commit configuration to see if mypy is enabled
if [ -f ".pre-commit-config.yaml" ]; then
  cat .pre-commit-config.yaml
else
  echo "No .pre-commit-config.yaml found"
fi

Repository: MoAly98/everwillow

Length of output: 2098


🏁 Script executed:

#!/bin/bash
# Check the State class definition to understand the from_pytree method return type
sed -n '1,100p' src/everwillow/statelib/state.py | cat -n

Repository: MoAly98/everwillow

Length of output: 3734


🏁 Script executed:

#!/bin/bash
# Search for from_pytree method definition in State class
rg -A 5 "def from_pytree" src/everwillow/statelib/state.py

Repository: MoAly98/everwillow

Length of output: 220


🏁 Script executed:

#!/bin/bash
# Get the complete from_pytree method signature including return type
rg -A 15 "def from_pytree" src/everwillow/statelib/state.py | head -30

Repository: MoAly98/everwillow

Length of output: 716


🏁 Script executed:

#!/bin/bash
# Check what class this method belongs to and its definition
rg -B 10 "def from_pytree" src/everwillow/statelib/state.py | grep -E "^class|def from_pytree"

Repository: MoAly98/everwillow

Length of output: 82


🏁 Script executed:

#!/bin/bash
# Look at the pyproject.toml to see mypy configuration
grep -A 20 "\[tool.mypy\]" pyproject.toml

Repository: MoAly98/everwillow

Length of output: 722


🏁 Script executed:

#!/bin/bash
# Check for other State type annotations in the codebase to see the pattern
rg -n ":\s*sl\.State\[" src/everwillow/ | head -20

Repository: MoAly98/everwillow

Length of output: 1480


🏁 Script executed:

#!/bin/bash
# Check the fit_free.params type to understand what type V should be
rg -n "def fit" src/everwillow/ --type py | grep -v test | head -10

Repository: MoAly98/everwillow

Length of output: 110


🏁 Script executed:

#!/bin/bash
# Look for any type stub or type information about fit_free return type
find . -name "*.pyi" | xargs grep -l "fit" 2>/dev/null

Repository: MoAly98/everwillow

Length of output: 46


🏁 Script executed:

#!/bin/bash
# Check if there's a recent mypy error log or CI failure related to these lines
git log --oneline --all --grep="mypy\|type" | head -20

Repository: MoAly98/everwillow

Length of output: 564


🏁 Script executed:

#!/bin/bash
# Search for variable assignments annotated with sl.State in the codebase (looking for patterns)
rg -n ":\s*sl\.State\s*=" src/ | head -20

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check what type V would be in context - look at the fit function's return type
sed -n '302,330p' src/everwillow/inference/fitting.py | cat -n

Repository: MoAly98/everwillow

Length of output: 1654


🏁 Script executed:

#!/bin/bash
# Look for any recent commits or PRs that mention the test_statistics file
git log --oneline --all -- src/everwillow/inference/hypotest/test_statistics.py | head -10

Repository: MoAly98/everwillow

Length of output: 154


🏁 Script executed:

#!/bin/bash
# Check the full function signature of _compute_q to see parameter types
sed -n '123,145p' src/everwillow/inference/hypotest/test_statistics.py | cat -n

Repository: MoAly98/everwillow

Length of output: 1111


🏁 Script executed:

#!/bin/bash
# Look for FitResult definition to understand the type of params attribute
rg -n "class FitResult" src/everwillow/ | head -5

Repository: MoAly98/everwillow

Length of output: 144


🏁 Script executed:

#!/bin/bash
# Check if there are any existing issues or PRs discussing mypy errors in this file
git log --all --oneline --grep="test_statistics" | head -10

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check FitResult class and params attribute type
sed -n '52,80p' src/everwillow/inference/fitting.py | cat -n

Repository: MoAly98/everwillow

Length of output: 1036


🏁 Script executed:

#!/bin/bash
# Run a quick check: does the test file pass mypy currently?
# Look at the structure to understand if this is already a known issue
git diff HEAD -- src/everwillow/inference/hypotest/test_statistics.py 2>/dev/null | head -50

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check how constrained_fit function is typed to understand if fixed should have a type parameter
rg -B 5 "def constrained_fit" src/everwillow/inference/hypotest/_utils.py | head -20

Repository: MoAly98/everwillow

Length of output: 191


🏁 Script executed:

#!/bin/bash
# Get a complete view of how the State class is used in annotations throughout the codebase
rg -n ":\s*sl\.State\b" src/ | grep -v "\[" | head -20

Repository: MoAly98/everwillow

Length of output: 1247


🏁 Script executed:

#!/bin/bash
# Check the constrained_fit function signature more completely
sed -n '25,45p' src/everwillow/inference/hypotest/_utils.py | cat -n

Repository: MoAly98/everwillow

Length of output: 929


🏁 Script executed:

#!/bin/bash
# Look for any type: ignore comments or similar workarounds in the test file
grep -n "type: ignore\|noqa" src/everwillow/inference/hypotest/test_statistics.py

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check if State class is properly exported and what the actual import looks like
rg -n "State" src/everwillow/statelib/__init__.py

Repository: MoAly98/everwillow

Length of output: 90


Add explicit type annotations to fitted_state and fixed to resolve mypy errors.

With check_untyped_defs = true, mypy type-checks this function body and reports "Need type annotation" where it cannot infer a local variable's type (CI reports this for fixed). Adding explicit sl.State annotations resolves these failures:

Fix
-        fitted_state = sl.State.from_pytree(fit_free.params)
+        fitted_state: sl.State = sl.State.from_pytree(fit_free.params)
 
-        fixed = sl.State.from_pytree({poi_key: poi_test})
+        fixed: sl.State = sl.State.from_pytree({poi_key: poi_test})
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 fit_free = ew.fit(nll_fn, params, **fit_kwargs)
-fitted_state = sl.State.from_pytree(fit_free.params)
+fitted_state: sl.State = sl.State.from_pytree(fit_free.params)
 mu_hat = fitted_state[poi_key]
 # Constrained fit (POI fixed at test value)
-fixed = sl.State.from_pytree({poi_key: poi_test})
+fixed: sl.State = sl.State.from_pytree({poi_key: poi_test})
 fit_constrained = constrained_fit(nll_fn, params, fixed, **fit_kwargs)
🧰 Tools
🪛 GitHub Actions: CI

[error] 135-135: Mypy: Need type annotation for 'fixed'.

🤖 Prompt for AI Agents
In `@src/everwillow/inference/hypotest/test_statistics.py` around lines 132 - 138,
The local variables fitted_state and fixed need explicit type annotations to
satisfy mypy's check_untyped_defs; annotate them as sl.State (e.g., declare
fitted_state: sl.State = sl.State.from_pytree(...) and fixed: sl.State =
sl.State.from_pytree(...)) so both variables are explicitly typed when calling
sl.State.from_pytree in the block that creates mu_hat and the constrained fit
(references: fitted_state, fixed, sl.State.from_pytree, poi_key, poi_test).

Comment on lines +153 to +165
def single_toy(key: PRNGKeyArray) -> Array:
    # Generate toy data
    toy_data = sample_fn(sample_params, key)
    # Create NLL for this toy
    toy_nll = nll_factory(sample_params, toy_data)
    # Compute test statistic
    result = self.test_statistic(
        toy_nll, fit_params, poi_key, poi_test, **fit_kwargs
    )
    return result.q

# Run toys in parallel using vmap
q_toys = jax.vmap(single_toy)(keys)
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's see the full context of the _run_toys method and ToyGenerator class
head -200 src/everwillow/inference/hypotest/toys.py | tail -100

Repository: MoAly98/everwillow

Length of output: 3681


🏁 Script executed:

#!/bin/bash
# Search for ToyGenerator usage across the codebase
rg -n "ToyGenerator" --type=py -B 2 -A 5

Repository: MoAly98/everwillow

Length of output: 3319


🏁 Script executed:

#!/bin/bash
# Check if there are any tests or examples for the ToyGenerator
fd -e py -path "*test*" | xargs rg -l "ToyGenerator" 2>/dev/null

Repository: MoAly98/everwillow

Length of output: 366


🏁 Script executed:

#!/bin/bash
# Look for jax.jit usage in the toys.py file
rg -n "jax.jit|@jit" src/everwillow/inference/hypotest/toys.py

Repository: MoAly98/everwillow

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check the docstring and signature of the _run_toys method
sed -n '100,200p' src/everwillow/inference/hypotest/toys.py

Repository: MoAly98/everwillow

Length of output: 2120


🏁 Script executed:

#!/bin/bash
# View the full generate method and its docstring
sed -n '50,130p' src/everwillow/inference/hypotest/toys.py

Repository: MoAly98/everwillow

Length of output: 2969


🏁 Script executed:

#!/bin/bash
# Check the example notebook to see how ToyGenerator is used in practice
head -500 examples/eft_limits_toys.ipynb | grep -A 20 "ToyGenerator"

Repository: MoAly98/everwillow

Length of output: 1642


🏁 Script executed:

#!/bin/bash
# Check the imports and see what constrained_fit is and where it comes from
rg -n "constrained_fit|test_statistic" src/everwillow/inference/hypotest/toys.py | head -20

Repository: MoAly98/everwillow

Length of output: 594


🏁 Script executed:

#!/bin/bash
# Look at the full class definition to see if there's any mention of JIT or documentation
head -60 src/everwillow/inference/hypotest/toys.py

Repository: MoAly98/everwillow

Length of output: 1966


Document that sample_fn and nll_factory must be JAX-compatible or wrap _run_toys in jax.jit.

The vmap(single_toy) pattern depends on both sample_fn and nll_factory being traceable and free of Python side effects. If either function uses non-JAX operations, tracing under vmap can fail, and any Python side effects run once at trace time rather than once per toy. Either add a docstring note explicitly requiring JAX-compatible functions, or wrap the entire _run_toys method body in jax.jit so the toy loop is traced and compiled as a unit.

🤖 Prompt for AI Agents
In `@src/everwillow/inference/hypotest/toys.py` around lines 153 - 165, The vmap
over single_toy assumes sample_fn and nll_factory are JAX-traceable and
side-effect-free; if they are not, tracing can fail, and Python side effects run
once at trace time rather than once per toy.
Fix by either documenting this requirement in the _run_toys docstring
(explicitly state sample_fn and nll_factory must be JAX-compatible/functional
and avoid Python-side effects) or by ensuring consistent compilation by wrapping
the _run_toys body (including single_toy and the jax.vmap call that produces
q_toys) with jax.jit so the whole toy loop is traced/compiled once; reference
single_toy, sample_fn, nll_factory, _run_toys and the jax.vmap(keys) call to
locate where to apply the change.
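
The jit-wrapping option could look roughly like the sketch below. This is a minimal illustration under stated assumptions, not everwillow's actual code: sample_fn, nll_factory, and single_toy are simplified stand-ins (a Gaussian sampler, a least-squares NLL, and a closed-form best fit), and run_toys is a hypothetical free function playing the role of _run_toys.

```python
import jax
import jax.numpy as jnp


def sample_fn(key, n_bins=4):
    # Hypothetical sampler: pure JAX, no Python-side effects.
    return 10.0 + jax.random.normal(key, (n_bins,))


def nll_factory(data):
    # Hypothetical NLL builder: returns a traceable closure over the toy data.
    def nll(mu):
        return 0.5 * jnp.sum((data - mu) ** 2)

    return nll


def single_toy(key):
    data = sample_fn(key)
    nll = nll_factory(data)
    mu_hat = jnp.mean(data)  # closed-form minimiser of this toy NLL
    # Likelihood-ratio style statistic for the tested value mu = 10.
    return 2.0 * (nll(10.0) - nll(mu_hat))


@jax.jit
def run_toys(keys):
    # vmap batches single_toy over the keys; jit compiles the batched
    # computation once per input shape instead of re-tracing on each call.
    return jax.vmap(single_toy)(keys)


keys = jax.random.split(jax.random.PRNGKey(0), 1000)
q_toys = run_toys(keys)
```

If sample_fn or nll_factory converted a traced value with float() or called into non-JAX code, the first call to run_toys would be expected to fail during tracing, which is the requirement the docstring note would make explicit.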

Comment on lines +125 to +134
    Args:
        objective_fn: Function mapping (poi, key) to quantity of interest.
            Should be monotonic (typically decreasing) as POI increases.
            Must be JAX-traceable (no float() calls on traced values).
        bounds: (lower, upper) search range for POI value.
        key: JAX PRNG key for reproducibility.
        level: Target value for the objective function (default 0.05).
        tol: Stop when objective is within tol of level (default 0.02).
        max_iterations: Maximum bisection iterations (default 15).

⚠️ Potential issue | 🟡 Minor

Docstring default for max_iterations is out of sync.

Signature defaults to 100, but the docstring says 15.

📝 Suggested docstring fix
-        max_iterations: Maximum bisection iterations (default 15).
+        max_iterations: Maximum bisection iterations (default 100).
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    Args:
-        objective_fn: Function mapping (poi, key) to quantity of interest.
-            Should be monotonic (typically decreasing) as POI increases.
-            Must be JAX-traceable (no float() calls on traced values).
-        bounds: (lower, upper) search range for POI value.
-        key: JAX PRNG key for reproducibility.
-        level: Target value for the objective function (default 0.05).
-        tol: Stop when objective is within tol of level (default 0.02).
-        max_iterations: Maximum bisection iterations (default 15).
+    Args:
+        objective_fn: Function mapping (poi, key) to quantity of interest.
+            Should be monotonic (typically decreasing) as POI increases.
+            Must be JAX-traceable (no float() calls on traced values).
+        bounds: (lower, upper) search range for POI value.
+        key: JAX PRNG key for reproducibility.
+        level: Target value for the objective function (default 0.05).
+        tol: Stop when objective is within tol of level (default 0.02).
+        max_iterations: Maximum bisection iterations (default 100).
🤖 Prompt for AI Agents
In `@src/everwillow/inference/hypotest/upper_limit.py` around lines 125 - 134, The
docstring for the function upper_limit (parameter max_iterations) is
inconsistent with the function signature: the signature default is 100 but the
docstring states 15; update the docstring to state "Maximum bisection iterations
(default 100)" (or change the function signature default to 15 if intended) so
the textual docs match the actual parameter default; specifically edit the
parameter description for max_iterations in the upper_limit function's docstring
to reference the correct default value.
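
Drift of this kind can be caught mechanically. The sketch below is one possible check, assuming Google-style "name: ... (default X)." entries; the upper_limit stub and the docstring_default_mismatches helper are hypothetical and only mirror the mismatch described above.

```python
import inspect
import re


def upper_limit(objective_fn, bounds, level=0.05, tol=0.02, max_iterations=100):
    """Hypothetical stub that reproduces the mismatch from the review.

    Args:
        level: Target value for the objective function (default 0.05).
        tol: Stop when objective is within tol of level (default 0.02).
        max_iterations: Maximum bisection iterations (default 15).
    """


def docstring_default_mismatches(fn):
    """Return {param: (documented, actual)} where the docstring disagrees."""
    sig = inspect.signature(fn)
    doc = inspect.getdoc(fn) or ""
    mismatches = {}
    # Match single-line entries such as "tol: ... (default 0.02)."
    pattern = r"^\s*(\w+):.*\(default ([^)]+)\)"
    for name, documented in re.findall(pattern, doc, flags=re.M):
        param = sig.parameters.get(name)
        if param is None or param.default is inspect.Parameter.empty:
            continue
        if repr(param.default) != documented.strip():
            mismatches[name] = (documented.strip(), repr(param.default))
    return mismatches


print(docstring_default_mismatches(upper_limit))
```

The comparison is textual (repr of the default against the documented string), so it only suits simple literal defaults such as numbers.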

Comment on lines +220 to +248
    solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}

    # Observed limit
    observed = upper_limit(
        lambda poi: calc_fn(poi).cl_s,
        bounds,
        level,
        **solver_kwargs,
    )

    # Expected limits at each sigma band
    def limit_at_band(band_idx: int) -> Array:
        """Compute limit where expected CLs at given band equals level."""

        def objective(poi: float) -> Array:
            result = calc_fn(poi)
            bands = result.expected_bands
            # Each band contains (pnull, palt) tuple
            band_list = [
                bands.minus_2sigma,
                bands.minus_1sigma,
                bands.median,
                bands.plus_1sigma,
                bands.plus_2sigma,
            ]
            pnull, palt = band_list[band_idx]
            return cl_s(palt, pnull)

        return upper_limit(objective, bounds, level, **solver_kwargs)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file structure and confirm the code at lines 220-248
wc -l src/everwillow/inference/hypotest/upper_limit.py

Repository: MoAly98/everwillow

Length of output: 114


🏁 Script executed:

# Extract the relevant section
sed -n '220,248p' src/everwillow/inference/hypotest/upper_limit.py

Repository: MoAly98/everwillow

Length of output: 989


🏁 Script executed:

# Search for the upper_limit function definition to understand its signature
rg -n "def upper_limit" src/everwillow/inference/hypotest/upper_limit.py -A 10

Repository: MoAly98/everwillow

Length of output: 733


🏁 Script executed:

# Check all usages of solver_kwargs in the file
rg -n "solver_kwargs" src/everwillow/inference/hypotest/upper_limit.py

Repository: MoAly98/everwillow

Length of output: 241


🏁 Script executed:

# Get the function that contains lines 220-248
sed -n '1,220p' src/everwillow/inference/hypotest/upper_limit.py | tail -50

Repository: MoAly98/everwillow

Length of output: 1898


🏁 Script executed:

# Look for the function definition before line 220
rg -n "^def " src/everwillow/inference/hypotest/upper_limit.py | awk -F: '$1 < 220 {last=$0} END {print last}'

Repository: MoAly98/everwillow

Length of output: 91


🏁 Script executed:

# Get the full function signature by looking around where the solver_kwargs is used
sed -n '170,225p' src/everwillow/inference/hypotest/upper_limit.py

Repository: MoAly98/everwillow

Length of output: 2071


🏁 Script executed:

# Verify there are no other solver_kwargs usages in the entire file
rg -n "solver_kwargs" src/everwillow/inference/hypotest/upper_limit.py

Repository: MoAly98/everwillow

Length of output: 241


🏁 Script executed:

# Check if there are any other files in the codebase that might use similar patterns
fd -t f "\.py$" -x rg -l "solver_kwargs" {} \; 2>/dev/null | head -20

Repository: MoAly98/everwillow

Length of output: 146


🏁 Script executed:

# Check the test file to see if it has similar issues we should be aware of
rg -n "solver_kwargs" tests/inference/test_fitting.py -B 2 -A 2

Repository: MoAly98/everwillow

Length of output: 257


Pass solver kwargs explicitly to satisfy mypy (CI failure).

When solver_kwargs is created as {"rtol": rtol, "atol": atol, "max_steps": max_steps}, mypy infers a single value type for the whole dict. Because int is promoted to float, that type is float, so the dict is typed as dict[str, float]. Unpacking it with **solver_kwargs then presents every keyword, including max_steps, as float, which conflicts with the max_steps: int parameter of upper_limit(). Pass the arguments explicitly to preserve their types in both function calls.

Fix
-    solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}
-
     # Observed limit
     observed = upper_limit(
         lambda poi: calc_fn(poi).cl_s,
         bounds,
         level,
-        **solver_kwargs,
+        rtol=rtol,
+        atol=atol,
+        max_steps=max_steps,
     )
@@
-        return upper_limit(objective, bounds, level, **solver_kwargs)
+        return upper_limit(
+            objective,
+            bounds,
+            level,
+            rtol=rtol,
+            atol=atol,
+            max_steps=max_steps,
+        )
📝 Committable suggestion


Suggested change
-    solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}
-    # Observed limit
-    observed = upper_limit(
-        lambda poi: calc_fn(poi).cl_s,
-        bounds,
-        level,
-        **solver_kwargs,
-    )
-    # Expected limits at each sigma band
-    def limit_at_band(band_idx: int) -> Array:
-        """Compute limit where expected CLs at given band equals level."""
-        def objective(poi: float) -> Array:
-            result = calc_fn(poi)
-            bands = result.expected_bands
-            # Each band contains (pnull, palt) tuple
-            band_list = [
-                bands.minus_2sigma,
-                bands.minus_1sigma,
-                bands.median,
-                bands.plus_1sigma,
-                bands.plus_2sigma,
-            ]
-            pnull, palt = band_list[band_idx]
-            return cl_s(palt, pnull)
-        return upper_limit(objective, bounds, level, **solver_kwargs)
+    # Observed limit
+    observed = upper_limit(
+        lambda poi: calc_fn(poi).cl_s,
+        bounds,
+        level,
+        rtol=rtol,
+        atol=atol,
+        max_steps=max_steps,
+    )
+    # Expected limits at each sigma band
+    def limit_at_band(band_idx: int) -> Array:
+        """Compute limit where expected CLs at given band equals level."""
+        def objective(poi: float) -> Array:
+            result = calc_fn(poi)
+            bands = result.expected_bands
+            # Each band contains (pnull, palt) tuple
+            band_list = [
+                bands.minus_2sigma,
+                bands.minus_1sigma,
+                bands.median,
+                bands.plus_1sigma,
+                bands.plus_2sigma,
+            ]
+            pnull, palt = band_list[band_idx]
+            return cl_s(palt, pnull)
+        return upper_limit(
+            objective,
+            bounds,
+            level,
+            rtol=rtol,
+            atol=atol,
+            max_steps=max_steps,
+        )
🤖 Prompt for AI Agents
In `@src/everwillow/inference/hypotest/upper_limit.py` around lines 220 - 248,
Replace the **solver_kwargs unpacking with explicit keyword args to preserve
precise types: call upper_limit for observed using rtol=rtol, atol=atol,
max_steps=max_steps (instead of **solver_kwargs), and likewise in limit_at_band
return upper_limit(..., rtol=rtol, atol=atol, max_steps=max_steps); keep the
same arguments to upper_limit, using calc_fn, cl_s, bounds, and level as before
and remove reliance on the solver_kwargs dict for these calls.
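
The typing difference can be reproduced on a small scale. In the sketch below, upper_limit is a hypothetical plain-Python bisection stand-in with the same keyword names (rtol, atol, max_steps); it is not the library's solver, and the mypy behaviour described in the comments is an expectation, not captured output.

```python
import math
from typing import Callable, Tuple


def upper_limit(
    objective: Callable[[float], float],
    bounds: Tuple[float, float],
    level: float,
    *,
    rtol: float = 1e-4,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> float:
    """Hypothetical stand-in: bisect a decreasing objective down to `level`."""
    lo, hi = bounds
    for _ in range(max_steps):
        mid = 0.5 * (lo + hi)
        if objective(mid) > level:
            lo = mid  # objective still above target, so the limit lies higher
        else:
            hi = mid
        if hi - lo <= atol + rtol * abs(mid):
            break
    return 0.5 * (lo + hi)


rtol, atol, max_steps = 1e-6, 1e-9, 200

# A mixed dict literal gets one value type; mypy is expected to infer
# dict[str, float] here, so **solver_kwargs would offer max_steps as float.
solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}

# Explicit keywords keep each argument's own type (max_steps stays int).
limit = upper_limit(
    lambda poi: math.exp(-poi),
    (0.0, 10.0),
    0.05,
    rtol=rtol,
    atol=atol,
    max_steps=max_steps,
)
print(limit)  # close to math.log(20)
```

Both call styles behave identically at runtime; the difference only matters to the type checker. Annotating the dict with a TypedDict should be another way to keep per-key types if the dict form is preferred.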

Comment on lines 226 to +235
Examples:
>>> State.from_pytree({"a": [1, 2]}).mapping
-mappingproxy({('a', 0): 1.0, ('a', 1): 2.0})
+mappingproxy({('a', 0): 1, ('a', 1): 2})
Round-trip with pre-canonicalized keys:
>>> state = State.from_pytree({"x": 1.0, "y": 2.0})
>>> flat = state.to_dict() # {('x',): 1.0, ('y',): 2.0}
>>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
True

⚠️ Potential issue | 🟡 Minor

Fix doc example: canonicalize arg isn’t in State.from_pytree.

The example will raise TypeError because the method doesn’t accept canonicalize. Update the docstring (or add the parameter if intended).

📝 Suggested docstring fix
-            >>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
+            >>> State.from_pytree(flat).to_dict() == flat
📝 Committable suggestion


Suggested change
-        Examples:
-            >>> State.from_pytree({"a": [1, 2]}).mapping
-            mappingproxy({('a', 0): 1.0, ('a', 1): 2.0})
-            mappingproxy({('a', 0): 1, ('a', 1): 2})
-            Round-trip with pre-canonicalized keys:
-            >>> state = State.from_pytree({"x": 1.0, "y": 2.0})
-            >>> flat = state.to_dict() # {('x',): 1.0, ('y',): 2.0}
-            >>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
-            True
+        Examples:
+            >>> State.from_pytree({"a": [1, 2]}).mapping
+            mappingproxy({('a', 0): 1, ('a', 1): 2})
+            Round-trip with pre-canonicalized keys:
+            >>> state = State.from_pytree({"x": 1.0, "y": 2.0})
+            >>> flat = state.to_dict() # {('x',): 1.0, ('y',): 2.0}
+            >>> State.from_pytree(flat).to_dict() == flat
+            True
🤖 Prompt for AI Agents
In `@src/everwillow/statelib/state.py` around lines 226 - 235, The docstring
example incorrectly calls State.from_pytree(..., canonicalize=False) even though
State.from_pytree does not accept a canonicalize parameter; update the
documentation to remove the canonicalize argument and instead show the intended
round-trip using a pre-canonicalized dict (e.g., show creating flat =
state.to_dict() and then State.from_pytree(flat).to_dict() == flat) or, if the
behavior was intended, add a canonicalize parameter to State.from_pytree and
implement handling; reference the State.from_pytree and to_dict symbols when
making the change.
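
To illustrate both halves of this comment, here is a small stand-in. ToyState and flatten are hypothetical and only imitate the documented behaviour (tuple-path keys, a from_pytree without a canonicalize parameter); the real State class may flatten pytrees differently.

```python
from types import MappingProxyType


def flatten(tree, prefix=()):
    """Flatten nested dicts/lists into {path_tuple: leaf}."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return {prefix: tree}
    flat = {}
    for key, value in items:
        # Keys that are already tuples are treated as canonical paths.
        path = key if isinstance(key, tuple) else (key,)
        flat.update(flatten(value, prefix + path))
    return flat


class ToyState:
    """Hypothetical stand-in for State with the same method names."""

    def __init__(self, mapping):
        self.mapping = MappingProxyType(dict(mapping))

    @classmethod
    def from_pytree(cls, tree):  # note: no `canonicalize` parameter
        return cls(flatten(tree))

    def to_dict(self):
        return dict(self.mapping)


state = ToyState.from_pytree({"x": 1.0, "y": 2.0})
flat = state.to_dict()
round_trip_ok = ToyState.from_pytree(flat).to_dict() == flat

try:
    ToyState.from_pytree(flat, canonicalize=False)
    raised = False
except TypeError:
    # Unknown keyword arguments are rejected at call time.
    raised = True
```

In this stand-in the round trip succeeds without any extra argument, and the call with canonicalize=False raises TypeError, which is the failure mode the docstring example would hit under doctest.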
