First implementation of hypothesis testing #35
base: main
…ion if user is providing already-canonical pytrees. Allow only if all keys are tuples and are not nested.
📝 Walkthrough
Introduces a comprehensive hypothesis testing framework with test statistics (QTilde, QMu, Q0, TMu), asymptotic and empirical distributions, an orchestrator calculator, toy-based Monte Carlo support, and upper limit solvers. Also adds parameter uncertainty estimation via Hessian-based covariance and expands the public API surface. Includes an example notebook demonstrating EFT limit extraction.
Sequence Diagram(s)
sequenceDiagram
participant User
participant HypoTestCalculator
participant TestStatistic as QTilde
participant Distribution
participant FitUtils as fit/constrained_fit
User->>HypoTestCalculator: __call__(nll_fn, params, poi_key, poi_test, distribution)
HypoTestCalculator->>TestStatistic: compute test statistic
TestStatistic->>FitUtils: fit at poi_test (constrained)
FitUtils-->>TestStatistic: fit_constrained, mu_hat
TestStatistic->>FitUtils: fit unconstrained
FitUtils-->>TestStatistic: fit_free
TestStatistic->>TestStatistic: compute q = -2ln(L(poi_test)/L(mu_hat))
TestStatistic-->>HypoTestCalculator: TestStatResult(q, extras)
HypoTestCalculator->>Distribution: pvalues(test_stat_result)
Distribution->>Distribution: compute pnull, palt from q
Distribution-->>HypoTestCalculator: (pnull, palt)
HypoTestCalculator->>HypoTestCalculator: cl_s = palt / pnull
HypoTestCalculator->>Distribution: expected_pvalues(test_stat_result)
Distribution-->>HypoTestCalculator: ExpectedBands(-2σ...+2σ)
HypoTestCalculator-->>User: HypoTestResult(q_obs, pnull, palt, cl_s, expected_bands)
sequenceDiagram
participant User
participant ToyGenerator
participant SampleFn as sample_fn
participant NLLFactory as nll_factory
participant TestStat as TestStatistic
participant JAX as jax.vmap
User->>ToyGenerator: generate(nll_fn, params, poi_key, poi_test, sample_fn, nll_factory, key)
ToyGenerator->>ToyGenerator: profile nuisance params at poi_test (alt & null)
ToyGenerator->>JAX: _run_toys(parallel execution)
loop for each toy iteration
JAX->>SampleFn: sample_fn(sample_params, key)
SampleFn-->>JAX: toy_data
JAX->>NLLFactory: nll_factory(toy_data)
NLLFactory-->>JAX: toy_nll
JAX->>TestStat: compute test statistic on toy_nll
TestStat-->>JAX: q_toy
end
JAX-->>ToyGenerator: arrays [q_alt, q_null]
ToyGenerator-->>User: EmpiricalDistribution(q_alt=..., q_null=...)
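To make the `q` and `cl_s` steps in the first diagram concrete, here is a minimal sketch of the asymptotic CLs computation for a `q_mu`-style statistic. It uses only the standard library and is not taken from the PR: the helper names (`q_from_nll`, `asymptotic_cls`) are hypothetical, and the mapping of `p_sb` and `p_b_tail` onto the PR's `palt`/`pnull` naming is an assumption.

```python
import math


def normal_cdf(x: float) -> float:
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def q_from_nll(nll_fixed: float, nll_free: float) -> float:
    """q = -2 ln(L(poi_test) / L(mu_hat)) = 2 * (NLL(poi_test) - NLL(mu_hat)).

    Clipped at zero to guard against small negative values from fit noise.
    """
    return max(2.0 * (nll_fixed - nll_free), 0.0)


def asymptotic_cls(q_obs: float, q_asimov: float) -> tuple[float, float, float]:
    """Return (p_sb, p_b_tail, cls) for a q_mu-style statistic.

    p_sb     : P(q >= q_obs | signal+background) = 1 - Phi(sqrt(q_obs))
    p_b_tail : P(q >= q_obs | background-only)   = 1 - Phi(sqrt(q_obs) - sqrt(q_asimov))
    cls      : p_sb / p_b_tail
    """
    sqrt_q = math.sqrt(max(q_obs, 0.0))
    sqrt_qa = math.sqrt(max(q_asimov, 0.0))
    p_sb = 1.0 - normal_cdf(sqrt_q)
    p_b_tail = 1.0 - normal_cdf(sqrt_q - sqrt_qa)
    return p_sb, p_b_tail, p_sb / p_b_tail


# Hypothetical numbers: constrained-fit NLL of 10.0, free-fit NLL of 8.0.
q_obs = q_from_nll(10.0, 8.0)
p_sb, p_b_tail, cls = asymptotic_cls(q_obs, q_asimov=4.0)
```

When `q_obs` equals `q_asimov`, the background-only tail probability is exactly 0.5, so CLs is twice the signal+background p-value.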
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@src/everwillow/inference/hypotest/test_statistics.py`:
- Around line 132-138: The local variables fitted_state and fixed need explicit
type annotations to satisfy mypy's check_untyped_defs; annotate them as sl.State
(e.g., declare fitted_state: sl.State = sl.State.from_pytree(...) and fixed:
sl.State = sl.State.from_pytree(...)) so both variables are explicitly typed
when calling sl.State.from_pytree in the block that creates mu_hat and the
constrained fit (references: fitted_state, fixed, sl.State.from_pytree, poi_key,
poi_test).
In `@src/everwillow/inference/hypotest/toys.py`:
- Around line 153-165: The vmap over single_toy assumes sample_fn and
nll_factory are JAX-traceable and side-effect-free; if they are not, vmap will
recompile per iteration or fail. Fix by either documenting this requirement in
the _run_toys docstring (explicitly state sample_fn and nll_factory must be
JAX-compatible/functional and avoid Python-side effects) or by ensuring
consistent compilation by wrapping the _run_toys body (including single_toy and
the jax.vmap call that produces q_toys) with jax.jit so the whole toy loop is
traced/compiled once; reference single_toy, sample_fn, nll_factory, _run_toys
and the jax.vmap(keys) call to locate where to apply the change.
In `@src/everwillow/inference/hypotest/upper_limit.py`:
- Around line 125-134: The docstring for the function upper_limit (parameter
max_iterations) is inconsistent with the function signature: the signature
default is 100 but the docstring states 15; update the docstring to state
"Maximum bisection iterations (default 100)" (or change the function signature
default to 15 if intended) so the textual docs match the actual parameter
default; specifically edit the parameter description for max_iterations in the
upper_limit function's docstring to reference the correct default value.
- Around line 220-248: Replace the **solver_kwargs unpacking with explicit
keyword args to preserve precise types: call upper_limit for observed using
rtol=rtol, atol=atol, max_steps=max_steps (instead of **solver_kwargs), and
likewise in limit_at_band return upper_limit(..., rtol=rtol, atol=atol,
max_steps=max_steps); keep the same arguments to upper_limit, using calc_fn,
cl_s, bounds, and level as before and remove reliance on the solver_kwargs dict
for these calls.
In `@src/everwillow/statelib/state.py`:
- Around line 226-235: The docstring example incorrectly calls
State.from_pytree(..., canonicalize=False) even though State.from_pytree does
not accept a canonicalize parameter; update the documentation to remove the
canonicalize argument and instead show the intended round-trip using a
pre-canonicalized dict (e.g., show creating flat = state.to_dict() and then
State.from_pytree(flat).to_dict() == flat) or, if the behavior was intended, add
a canonicalize parameter to State.from_pytree and implement handling; reference
the State.from_pytree and to_dict symbols when making the change.
🧹 Nitpick comments (8)
src/everwillow/inference/hypotest/distributions.py (4)
24-31: Consider sorting __all__ for consistency.
The linter flags that __all__ is not sorted. While this is a minor style concern, sorting can improve maintainability and make diffs cleaner when adding new exports.
♻️ Suggested fix
  __all__ = [
      "Distribution",
+     "EmpiricalDistribution",
+     "Q0Asymptotic",
+     "QMuAsymptotic",
      "QTildeAsymptotic",
-     "QMuAsymptotic",
-     "Q0Asymptotic",
      "TMuAsymptotic",
-     "EmpiricalDistribution",
  ]
172-175: Unused q_asimov parameter in _palt.
The q_asimov parameter is declared but never used in this method. This appears intentional since QMuAsymptotic doesn't apply boundary handling, but the parameter signature is inconsistent with the actual computation.
♻️ Suggested fix: Remove unused parameter or add comment
Option 1 - Remove unused parameter:
- def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:
+ def _palt(self, q: Array, nsigma: float = 0.0) -> Array:
      """p-value under alternative hypothesis (signal+background)."""
      sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
      return 1.0 - jax.scipy.stats.norm.cdf(sqrt_q - nsigma)
Option 2 - Add a comment explaining why it's kept for API consistency:
- def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:
+ def _palt(self, q: Array, q_asimov: Array, nsigma: float = 0.0) -> Array:  # noqa: ARG002
      """p-value under alternative hypothesis (signal+background)."""
+     # q_asimov unused here (no boundary handling), kept for API consistency
      sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
      return 1.0 - jax.scipy.stats.norm.cdf(sqrt_q - nsigma)
319-324: Potential numerical instability with small toy sample sizes.
When the q_null or q_alt arrays are small, jnp.mean(self.q_null >= q) could return extreme values (0.0 or 1.0), which may cause issues in downstream CLs calculations. Consider adding a note in the docstring about minimum recommended toy counts.
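The resolution problem can be seen in a few lines of plain Python. This is only an illustration, not the PR's `EmpiricalDistribution`: the function names are hypothetical, and the `(k + 1) / (n + 1)` estimator is a common Monte Carlo convention swapped in here for comparison rather than something the PR implements.

```python
def empirical_pvalue(q_toys: list[float], q_obs: float) -> float:
    """Plain tail fraction; resolution is 1/n, so it can be exactly 0.0 or 1.0."""
    return sum(q >= q_obs for q in q_toys) / len(q_toys)


def smoothed_pvalue(q_toys: list[float], q_obs: float) -> float:
    """(k + 1) / (n + 1) estimate; never exactly zero, so ratios stay finite."""
    k = sum(q >= q_obs for q in q_toys)
    return (k + 1) / (len(q_toys) + 1)


# Hypothetical toy test-statistic values; far too few for a real analysis.
q_toys = [0.1, 0.5, 1.2, 2.0]
p_plain = empirical_pvalue(q_toys, q_obs=5.0)    # no toy exceeds q_obs
p_smooth = smoothed_pvalue(q_toys, q_obs=5.0)
```

With only four toys, the plain estimate collapses to zero for any observation beyond the largest toy, and a CLs ratio with that value in the denominator would be undefined.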
326-351: Unused result parameter in expected_pvalues.
The result parameter is required by the abstract interface but is not used. This is acceptable for API consistency, but consider adding a comment or # noqa: ARG002 to suppress the linter warning.
♻️ Suggested fix
- def expected_pvalues(self, result: TestStatResult) -> ExpectedBands:
+ def expected_pvalues(self, result: TestStatResult) -> ExpectedBands:  # noqa: ARG002
      """Compute expected p-values at standard sigma bands from toy distributions.

      For each test statistic value in q_null, computes what the p-values
      would be, then takes percentiles at the standard normal quantiles.
      """
+     del result  # Unused; required by Distribution interface
src/everwillow/inference/hypotest/toys.py (1)
93-96: Hardcoded null hypothesis POI value of 0.0 may not be appropriate for all use cases.
The null hypothesis assumes POI = 0 (background-only), which is standard for signal strength parameters but may not be suitable for all hypothesis testing scenarios (e.g., testing μ = 1 vs μ = μ_alt).
Consider making the null hypothesis POI value configurable:
♻️ Suggested fix
  def generate(
      self,
      nll_fn: tp.Callable[[PyTree], float],
      params: sl.State,
      poi_key: sl.K,
      poi_test: float,
      *,
      sample_fn: tp.Callable[[sl.State, PRNGKeyArray], tp.Any],
      nll_factory: tp.Callable[[sl.State, tp.Any], tp.Callable[[PyTree], float]],
      key: PRNGKeyArray,
+     poi_null: float = 0.0,
      **fit_kwargs: tp.Any,
  ) -> EmpiricalDistribution:
      ...
      # Null hypothesis: POI = poi_null (background-only by default)
-     fixed_null = sl.State.from_pytree({poi_key: 0.0})
+     fixed_null = sl.State.from_pytree({poi_key: poi_null})
tests/inference/test_uncertainty.py (1)
48-48: Global JAX configuration should be isolated to test scope.
Setting jax.config.update("jax_enable_x64", True) at module level affects all subsequent tests in the test session. Consider using a pytest fixture or conftest.py to manage this configuration more safely.
♻️ Suggested approach using a fixture
# In conftest.py or at the top of this file
@pytest.fixture(autouse=True)
def enable_x64():
    """Enable 64-bit precision for these tests."""
    original = jax.config.jax_enable_x64
    jax.config.update("jax_enable_x64", True)
    yield
    jax.config.update("jax_enable_x64", original)
src/everwillow/inference/uncertainty.py (2)
101-103: Consider handling singular/ill-conditioned Hessian matrices.
jnp.linalg.inv will silently produce NaN or Inf for singular matrices, which can occur at parameter boundaries or with highly correlated parameters. Consider adding a condition number check or using pseudo-inverse with a warning.
♻️ Suggested improvement
  hess = hessian_matrix(nll_fn, params, fixed=fixed)
- # Invert Hessian to get Fisher information matrix (covariance)
- return jnp.linalg.inv(hess)
+ # Invert Hessian to get Fisher information matrix (covariance)
+ # Use pseudo-inverse for numerical stability with near-singular matrices
+ return jnp.linalg.pinv(hess)
Alternatively, check the condition number and warn if the matrix is ill-conditioned before inversion.
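As a rough illustration of the condition-number alternative, the sketch below checks a symmetric 2x2 Hessian before inverting it. It is a pure-Python stand-in written for this comment, not the PR's `covariance` function: the name `covariance_2x2`, the `max_cond` threshold, and the closed-form 2x2 handling are all assumptions made to keep the example self-contained.

```python
import math
import warnings


def covariance_2x2(hess: list[list[float]], max_cond: float = 1e12) -> list[list[float]]:
    """Invert a symmetric 2x2 Hessian after checking its conditioning."""
    (a, b), (_, d) = hess
    # Closed-form eigenvalues of the symmetric matrix [[a, b], [b, d]].
    mean = 0.5 * (a + d)
    radius = math.sqrt((0.5 * (a - d)) ** 2 + b * b)
    lam_max, lam_min = mean + radius, mean - radius
    if lam_min <= 0.0:
        raise ValueError("Hessian is singular or not positive definite")
    cond = lam_max / lam_min
    if cond > max_cond:
        warnings.warn(f"Hessian is ill-conditioned (cond={cond:.3g})", stacklevel=2)
    det = a * d - b * b
    return [[d / det, -b / det], [-b / det, a / det]]


# Hypothetical well-conditioned Hessian: uncertainties are sqrt(diag(cov)).
cov = covariance_2x2([[4.0, 0.0], [0.0, 1.0]])
sigmas = [math.sqrt(cov[0][0]), math.sqrt(cov[1][1])]
```

Raising on a non-positive eigenvalue surfaces the failure at the point of inversion instead of letting NaN or Inf propagate into the reported uncertainties.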
133-138: Return type annotation should reflect None values for fixed parameters.
The function returns None for fixed parameters (line 171), but the return type annotation is sl.State[V]. This should be sl.State[V | None] to accurately represent the return type.
📝 Suggested fix
  def uncertainties(
      nll_fn: tp.Callable[[PyTree[V]], float],
      params: sl.State[V],
      *,
      fixed: sl.State[V | EllipsisType] | None = None,
- ) -> sl.State[V]:
+ ) -> sl.State[V | None]:
fit_free = ew.fit(nll_fn, params, **fit_kwargs)
fitted_state = sl.State.from_pytree(fit_free.params)
mu_hat = fitted_state[poi_key]

# Constrained fit (POI fixed at test value)
fixed = sl.State.from_pytree({poi_key: poi_test})
fit_constrained = constrained_fit(nll_fn, params, fixed, **fit_kwargs)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, check if the file exists and examine lines 132-138
if [ -f "src/everwillow/inference/hypotest/test_statistics.py" ]; then
echo "=== File exists. Checking lines 130-140 ==="
sed -n '130,140p' src/everwillow/inference/hypotest/test_statistics.py | cat -n
else
echo "File not found. Listing directory structure:"
find . -name "test_statistics.py" -type f 2>/dev/null | head -20
fi
Repository: MoAly98/everwillow
Length of output: 655
🏁 Script executed:
#!/bin/bash
# Search for sl.State imports in the test file
rg -n "^import|^from" src/everwillow/inference/hypotest/test_statistics.py | head -20
Repository: MoAly98/everwillow
Length of output: 427
🏁 Script executed:
#!/bin/bash
# Check for any mypy or CI configuration files that might show errors
find . -name "*.yml" -o -name "*.yaml" -o -name "pyproject.toml" -o -name "setup.cfg" | grep -E "(github|gitlab|mypy|ci)" | head -10
Repository: MoAly98/everwillow
Length of output: 113
🏁 Script executed:
#!/bin/bash
# Check the CI configuration to see if mypy is being run
cat ./.github/workflows/ci.yml | head -100
Repository: MoAly98/everwillow
Length of output: 2052
🏁 Script executed:
#!/bin/bash
# Find sl.State definition to understand what type it is
find . -path "./.*" -prune -o -name "*.py" -type f -exec grep -l "class State" {} \; | grep -v __pycache__ | head -10
Repository: MoAly98/everwillow
Length of output: 96
🏁 Script executed:
#!/bin/bash
# Check if there are any existing type annotations on similar lines in the same file
rg -n ":\s*sl\.State\s*=" src/everwillow/inference/hypotest/test_statistics.py
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Look at the full context around lines 130-145 to understand the function signature and context
sed -n '120,150p' src/everwillow/inference/hypotest/test_statistics.py | cat -n
Repository: MoAly98/everwillow
Length of output: 1376
🏁 Script executed:
#!/bin/bash
# Check the pre-commit configuration to see if mypy is enabled
if [ -f ".pre-commit-config.yaml" ]; then
cat .pre-commit-config.yaml
else
echo "No .pre-commit-config.yaml found"
fi
Repository: MoAly98/everwillow
Length of output: 2098
🏁 Script executed:
#!/bin/bash
# Check the State class definition to understand the from_pytree method return type
sed -n '1,100p' src/everwillow/statelib/state.py | cat -n
Repository: MoAly98/everwillow
Length of output: 3734
🏁 Script executed:
#!/bin/bash
# Search for from_pytree method definition in State class
rg -A 5 "def from_pytree" src/everwillow/statelib/state.py
Repository: MoAly98/everwillow
Length of output: 220
🏁 Script executed:
#!/bin/bash
# Get the complete from_pytree method signature including return type
rg -A 15 "def from_pytree" src/everwillow/statelib/state.py | head -30
Repository: MoAly98/everwillow
Length of output: 716
🏁 Script executed:
#!/bin/bash
# Check what class this method belongs to and its definition
rg -B 10 "def from_pytree" src/everwillow/statelib/state.py | grep -E "^class|def from_pytree"
Repository: MoAly98/everwillow
Length of output: 82
🏁 Script executed:
#!/bin/bash
# Look at the pyproject.toml to see mypy configuration
grep -A 20 "\[tool.mypy\]" pyproject.toml
Repository: MoAly98/everwillow
Length of output: 722
🏁 Script executed:
#!/bin/bash
# Check for other State type annotations in the codebase to see the pattern
rg -n ":\s*sl\.State\[" src/everwillow/ | head -20
Repository: MoAly98/everwillow
Length of output: 1480
🏁 Script executed:
#!/bin/bash
# Check the fit_free.params type to understand what type V should be
rg -n "def fit" src/everwillow/ --type py | grep -v test | head -10
Repository: MoAly98/everwillow
Length of output: 110
🏁 Script executed:
#!/bin/bash
# Look for any type stub or type information about fit_free return type
find . -name "*.pyi" | xargs grep -l "fit" 2>/dev/null
Repository: MoAly98/everwillow
Length of output: 46
🏁 Script executed:
#!/bin/bash
# Check if there's a recent mypy error log or CI failure related to these lines
git log --oneline --all --grep="mypy\|type" | head -20
Repository: MoAly98/everwillow
Length of output: 564
🏁 Script executed:
#!/bin/bash
# Search for variable assignments annotated with sl.State in the codebase (looking for patterns)
rg -n ":\s*sl\.State\s*=" src/ | head -20
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Check what type V would be in context - look at the fit function's return type
sed -n '302,330p' src/everwillow/inference/fitting.py | cat -n
Repository: MoAly98/everwillow
Length of output: 1654
🏁 Script executed:
#!/bin/bash
# Look for any recent commits or PRs that mention the test_statistics file
git log --oneline --all -- src/everwillow/inference/hypotest/test_statistics.py | head -10
Repository: MoAly98/everwillow
Length of output: 154
🏁 Script executed:
#!/bin/bash
# Check the full function signature of _compute_q to see parameter types
sed -n '123,145p' src/everwillow/inference/hypotest/test_statistics.py | cat -n
Repository: MoAly98/everwillow
Length of output: 1111
🏁 Script executed:
#!/bin/bash
# Look for FitResult definition to understand the type of params attribute
rg -n "class FitResult" src/everwillow/ | head -5
Repository: MoAly98/everwillow
Length of output: 144
🏁 Script executed:
#!/bin/bash
# Check if there are any existing issues or PRs discussing mypy errors in this file
git log --all --oneline --grep="test_statistics" | head -10
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Check FitResult class and params attribute type
sed -n '52,80p' src/everwillow/inference/fitting.py | cat -n
Repository: MoAly98/everwillow
Length of output: 1036
🏁 Script executed:
#!/bin/bash
# Run a quick check: does the test file pass mypy currently?
# Look at the structure to understand if this is already a known issue
git diff HEAD -- src/everwillow/inference/hypotest/test_statistics.py 2>/dev/null | head -50
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Check how constrained_fit function is typed to understand if fixed should have a type parameter
rg -B 5 "def constrained_fit" src/everwillow/inference/hypotest/_utils.py | head -20
Repository: MoAly98/everwillow
Length of output: 191
🏁 Script executed:
#!/bin/bash
# Get a complete view of how the State class is used in annotations throughout the codebase
rg -n ":\s*sl\.State\b" src/ | grep -v "\[" | head -20
Repository: MoAly98/everwillow
Length of output: 1247
🏁 Script executed:
#!/bin/bash
# Check the constrained_fit function signature more completely
sed -n '25,45p' src/everwillow/inference/hypotest/_utils.py | cat -n
Repository: MoAly98/everwillow
Length of output: 929
🏁 Script executed:
#!/bin/bash
# Look for any type: ignore comments or similar workarounds in the test file
grep -n "type: ignore\|noqa" src/everwillow/inference/hypotest/test_statistics.py
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Check if State class is properly exported and what the actual import looks like
rg -n "State" src/everwillow/statelib/__init__.py
Repository: MoAly98/everwillow
Length of output: 90
Add explicit type annotations to fitted_state and fixed to resolve mypy errors.
The check_untyped_defs = true mypy configuration flags unannotated local variables. Adding explicit sl.State type annotations resolves these failures:
Fix
- fitted_state = sl.State.from_pytree(fit_free.params)
+ fitted_state: sl.State = sl.State.from_pytree(fit_free.params)
- fixed = sl.State.from_pytree({poi_key: poi_test})
+ fixed: sl.State = sl.State.from_pytree({poi_key: poi_test})
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
fit_free = ew.fit(nll_fn, params, **fit_kwargs)
fitted_state = sl.State.from_pytree(fit_free.params)
mu_hat = fitted_state[poi_key]
# Constrained fit (POI fixed at test value)
fixed = sl.State.from_pytree({poi_key: poi_test})
fit_constrained = constrained_fit(nll_fn, params, fixed, **fit_kwargs)

fit_free = ew.fit(nll_fn, params, **fit_kwargs)
fitted_state: sl.State = sl.State.from_pytree(fit_free.params)
mu_hat = fitted_state[poi_key]
# Constrained fit (POI fixed at test value)
fixed: sl.State = sl.State.from_pytree({poi_key: poi_test})
fit_constrained = constrained_fit(nll_fn, params, fixed, **fit_kwargs)
🧰 Tools
🪛 GitHub Actions: CI
[error] 135-135: Mypy: Need type annotation for 'fixed'.
def single_toy(key: PRNGKeyArray) -> Array:
    # Generate toy data
    toy_data = sample_fn(sample_params, key)
    # Create NLL for this toy
    toy_nll = nll_factory(sample_params, toy_data)
    # Compute test statistic
    result = self.test_statistic(
        toy_nll, fit_params, poi_key, poi_test, **fit_kwargs
    )
    return result.q

# Run toys in parallel using vmap
q_toys = jax.vmap(single_toy)(keys)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, let's see the full context of the _run_toys method and ToyGenerator class
head -200 src/everwillow/inference/hypotest/toys.py | tail -100
Repository: MoAly98/everwillow
Length of output: 3681
🏁 Script executed:
#!/bin/bash
# Search for ToyGenerator usage across the codebase
rg -n "ToyGenerator" --type=py -B 2 -A 5
Repository: MoAly98/everwillow
Length of output: 3319
🏁 Script executed:
#!/bin/bash
# Check if there are any tests or examples for the ToyGenerator
fd -e py -path "*test*" | xargs rg -l "ToyGenerator" 2>/dev/null
Repository: MoAly98/everwillow
Length of output: 366
🏁 Script executed:
#!/bin/bash
# Look for jax.jit usage in the toys.py file
rg -n "jax.jit|@jit" src/everwillow/inference/hypotest/toys.py
Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
#!/bin/bash
# Check the docstring and signature of the _run_toys method
sed -n '100,200p' src/everwillow/inference/hypotest/toys.py
Repository: MoAly98/everwillow
Length of output: 2120
🏁 Script executed:
#!/bin/bash
# View the full generate method and its docstring
sed -n '50,130p' src/everwillow/inference/hypotest/toys.py
Repository: MoAly98/everwillow
Length of output: 2969
🏁 Script executed:
#!/bin/bash
# Check the example notebook to see how ToyGenerator is used in practice
head -500 examples/eft_limits_toys.ipynb | grep -A 20 "ToyGenerator"
Repository: MoAly98/everwillow
Length of output: 1642
🏁 Script executed:
#!/bin/bash
# Check the imports and see what constrained_fit is and where it comes from
rg -n "constrained_fit|test_statistic" src/everwillow/inference/hypotest/toys.py | head -20
Repository: MoAly98/everwillow
Length of output: 594
🏁 Script executed:
#!/bin/bash
# Look at the full class definition to see if there's any mention of JIT or documentation
head -60 src/everwillow/inference/hypotest/toys.py
Repository: MoAly98/everwillow
Length of output: 1966
Document that sample_fn and nll_factory must be JAX-compatible or wrap _run_toys in jax.jit.
The vmap(single_toy) pattern depends on both sample_fn and nll_factory being traceable and free of Python-side effects. If nll_factory generates complex computational graphs or either function has non-JAX operations, vmapping may inefficiently recompile per toy or fail during tracing. Either add a docstring note explicitly requiring JAX-compatible functions, or wrap the entire _run_toys method body in @jax.jit for consistent compilation.
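The traceability requirement can be illustrated with a small standalone example. This is a sketch under stated assumptions, not the PR's `_run_toys`: the body of `single_toy` here is a made-up stand-in (sum of squares of Gaussian draws), and `trace_calls` exists only to expose when the Python body actually runs.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # records each Python-level execution of single_toy


def single_toy(key):
    # Python side effect: executed while JAX traces the function,
    # not once per toy.
    trace_calls.append(1)
    toy_data = jax.random.normal(key, (4,))  # stand-in for sample_fn
    return jnp.sum(toy_data**2)              # stand-in for the test statistic


keys = jax.random.split(jax.random.PRNGKey(0), 8)
q_toys = jax.vmap(single_toy)(keys)
```

Here `q_toys` holds eight values, yet the Python body ran a single time during tracing. A `sample_fn` or `nll_factory` that relies on per-call Python state (a counter, a NumPy RNG, file I/O) would therefore not behave as it does in an ordinary Python loop.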
Args:
    objective_fn: Function mapping (poi, key) to quantity of interest.
        Should be monotonic (typically decreasing) as POI increases.
        Must be JAX-traceable (no float() calls on traced values).
    bounds: (lower, upper) search range for POI value.
    key: JAX PRNG key for reproducibility.
    level: Target value for the objective function (default 0.05).
    tol: Stop when objective is within tol of level (default 0.02).
    max_iterations: Maximum bisection iterations (default 15).
Docstring default for max_iterations is out of sync.
Signature defaults to 100, but the docstring says 15.
📝 Suggested docstring fix
- max_iterations: Maximum bisection iterations (default 15).
+ max_iterations: Maximum bisection iterations (default 100).
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Args:
    objective_fn: Function mapping (poi, key) to quantity of interest.
        Should be monotonic (typically decreasing) as POI increases.
        Must be JAX-traceable (no float() calls on traced values).
    bounds: (lower, upper) search range for POI value.
    key: JAX PRNG key for reproducibility.
    level: Target value for the objective function (default 0.05).
    tol: Stop when objective is within tol of level (default 0.02).
    max_iterations: Maximum bisection iterations (default 15).

Args:
    objective_fn: Function mapping (poi, key) to quantity of interest.
        Should be monotonic (typically decreasing) as POI increases.
        Must be JAX-traceable (no float() calls on traced values).
    bounds: (lower, upper) search range for POI value.
    key: JAX PRNG key for reproducibility.
    level: Target value for the objective function (default 0.05).
    tol: Stop when objective is within tol of level (default 0.02).
    max_iterations: Maximum bisection iterations (default 100).
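For context on what `level`, `tol`, and `max_iterations` control, here is a minimal bisection sketch that follows the semantics described in the docstring. It is not the PR's implementation: the name `bisect_upper_limit` is hypothetical, the PRNG `key` argument is dropped, and the Python loop with `break` is not JAX-traceable, which the real function is documented to be.

```python
import math
from typing import Callable


def bisect_upper_limit(
    objective: Callable[[float], float],
    bounds: tuple[float, float],
    level: float = 0.05,
    tol: float = 0.02,
    max_iterations: int = 100,
) -> float:
    """Find poi with objective(poi) close to level, assuming a decreasing objective."""
    lo, hi = bounds
    mid = 0.5 * (lo + hi)
    for _ in range(max_iterations):
        mid = 0.5 * (lo + hi)
        value = objective(mid)
        if abs(value - level) <= tol:
            break  # objective is within tol of the target level
        if value > level:
            lo = mid  # still above the target: the crossing lies at larger poi
        else:
            hi = mid  # below the target: the crossing lies at smaller poi
    return mid


# Hypothetical CLs-like curve that decreases with the POI.
limit = bisect_upper_limit(lambda mu: math.exp(-mu), (0.0, 10.0), tol=1e-6)
```

Each iteration halves the search interval, so a budget of 100 iterations is far more than needed to exhaust double precision on any reasonable `bounds`; in practice the `tol` check ends the loop much earlier.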
solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}

# Observed limit
observed = upper_limit(
    lambda poi: calc_fn(poi).cl_s,
    bounds,
    level,
    **solver_kwargs,
)

# Expected limits at each sigma band
def limit_at_band(band_idx: int) -> Array:
    """Compute limit where expected CLs at given band equals level."""

    def objective(poi: float) -> Array:
        result = calc_fn(poi)
        bands = result.expected_bands
        # Each band contains (pnull, palt) tuple
        band_list = [
            bands.minus_2sigma,
            bands.minus_1sigma,
            bands.median,
            bands.plus_1sigma,
            bands.plus_2sigma,
        ]
        pnull, palt = band_list[band_idx]
        return cl_s(palt, pnull)

    return upper_limit(objective, bounds, level, **solver_kwargs)
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the file structure and confirm the code at lines 220-248
wc -l src/everwillow/inference/hypotest/upper_limit.py
Repository: MoAly98/everwillow
Length of output: 114
🏁 Script executed:
# Extract the relevant section
sed -n '220,248p' src/everwillow/inference/hypotest/upper_limit.py
Repository: MoAly98/everwillow
Length of output: 989
🏁 Script executed:
# Search for the upper_limit function definition to understand its signature
rg -n "def upper_limit" src/everwillow/inference/hypotest/upper_limit.py -A 10
Repository: MoAly98/everwillow
Length of output: 733
🏁 Script executed:
# Check all usages of solver_kwargs in the file
rg -n "solver_kwargs" src/everwillow/inference/hypotest/upper_limit.py
Repository: MoAly98/everwillow
Length of output: 241
🏁 Script executed:
# Get the function that contains lines 220-248
sed -n '1,220p' src/everwillow/inference/hypotest/upper_limit.py | tail -50
Repository: MoAly98/everwillow
Length of output: 1898
🏁 Script executed:
# Look for the function definition before line 220
rg -n "^def " src/everwillow/inference/hypotest/upper_limit.py | awk -F: '$1 < 220 {last=$0} END {print last}'
Repository: MoAly98/everwillow
Length of output: 91
🏁 Script executed:
# Get the full function signature by looking around where the solver_kwargs is used
sed -n '170,225p' src/everwillow/inference/hypotest/upper_limit.py
Repository: MoAly98/everwillow
Length of output: 2071
🏁 Script executed:
# Verify there are no other solver_kwargs usages in the entire file
rg -n "solver_kwargs" src/everwillow/inference/hypotest/upper_limit.py
Repository: MoAly98/everwillow
Length of output: 241
🏁 Script executed:
# Check if there are any other files in the codebase that might use similar patterns
fd -t f "\.py$" -x rg -l "solver_kwargs" {} \; 2>/dev/null | head -20
Repository: MoAly98/everwillow
Length of output: 146
🏁 Script executed:
# Check the test file to see if it has similar issues we should be aware of
rg -n "solver_kwargs" tests/inference/test_fitting.py -B 2 -A 2
Repository: MoAly98/everwillow
Length of output: 257
Pass solver kwargs explicitly to satisfy mypy (CI failure).
When solver_kwargs is created as {"rtol": rtol, "atol": atol, "max_steps": max_steps}, mypy infers the dict type as dict[str, float | int]. Unpacking this dict with **solver_kwargs loses type specificity, causing max_steps to be treated as float instead of int, which conflicts with the max_steps: int parameter in upper_limit(). Pass the arguments explicitly to preserve their types in both function calls.
Fix
-    solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}
-
     # Observed limit
     observed = upper_limit(
         lambda poi: calc_fn(poi).cl_s,
         bounds,
         level,
-        **solver_kwargs,
+        rtol=rtol,
+        atol=atol,
+        max_steps=max_steps,
     )
@@
-        return upper_limit(objective, bounds, level, **solver_kwargs)
+        return upper_limit(
+            objective,
+            bounds,
+            level,
+            rtol=rtol,
+            atol=atol,
+            max_steps=max_steps,
+        )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Current code:
    solver_kwargs = {"rtol": rtol, "atol": atol, "max_steps": max_steps}

    # Observed limit
    observed = upper_limit(
        lambda poi: calc_fn(poi).cl_s,
        bounds,
        level,
        **solver_kwargs,
    )

    # Expected limits at each sigma band
    def limit_at_band(band_idx: int) -> Array:
        """Compute limit where expected CLs at given band equals level."""

        def objective(poi: float) -> Array:
            result = calc_fn(poi)
            bands = result.expected_bands
            # Each band contains (pnull, palt) tuple
            band_list = [
                bands.minus_2sigma,
                bands.minus_1sigma,
                bands.median,
                bands.plus_1sigma,
                bands.plus_2sigma,
            ]
            pnull, palt = band_list[band_idx]
            return cl_s(palt, pnull)

        return upper_limit(objective, bounds, level, **solver_kwargs)

Suggested replacement:
    # Observed limit
    observed = upper_limit(
        lambda poi: calc_fn(poi).cl_s,
        bounds,
        level,
        rtol=rtol,
        atol=atol,
        max_steps=max_steps,
    )

    # Expected limits at each sigma band
    def limit_at_band(band_idx: int) -> Array:
        """Compute limit where expected CLs at given band equals level."""

        def objective(poi: float) -> Array:
            result = calc_fn(poi)
            bands = result.expected_bands
            # Each band contains (pnull, palt) tuple
            band_list = [
                bands.minus_2sigma,
                bands.minus_1sigma,
                bands.median,
                bands.plus_1sigma,
                bands.plus_2sigma,
            ]
            pnull, palt = band_list[band_idx]
            return cl_s(palt, pnull)

        return upper_limit(
            objective,
            bounds,
            level,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
🤖 Prompt for AI Agents
In `@src/everwillow/inference/hypotest/upper_limit.py` around lines 220 - 248,
Replace the **solver_kwargs unpacking with explicit keyword args to preserve
precise types: call upper_limit for observed using rtol=rtol, atol=atol,
max_steps=max_steps (instead of **solver_kwargs), and likewise in limit_at_band
return upper_limit(..., rtol=rtol, atol=atol, max_steps=max_steps); keep the
same arguments to upper_limit, using calc_fn, cl_s, bounds, and level as before
and remove reliance on the solver_kwargs dict for these calls.
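To make the role of bounds, level, rtol, atol, and max_steps concrete, here is a rough sketch of a solver that finds the POI value at which a CLs-like curve crosses the target level. It uses plain bisection and a hypothetical function name, `upper_limit_bisect`; it is an assumption for illustration and not the project's `upper_limit` implementation, whose root-finding method is not shown in this review.

```python
import math
from collections.abc import Callable


def upper_limit_bisect(
    func: Callable[[float], float],
    bounds: tuple[float, float],
    level: float = 0.05,
    *,
    rtol: float = 1e-6,
    atol: float = 1e-8,
    max_steps: int = 100,
) -> float:
    """Find poi in bounds with func(poi) == level by bisection (illustrative only)."""
    lo, hi = bounds
    g_lo = func(lo) - level
    g_hi = func(hi) - level
    if g_lo * g_hi > 0:
        raise ValueError("func(poi) - level does not change sign within bounds")
    for _ in range(max_steps):
        mid = 0.5 * (lo + hi)
        # Stop once the bracket is narrower than the combined tolerance.
        if hi - lo <= atol + rtol * abs(mid):
            return mid
        g_mid = func(mid) - level
        if g_lo * g_mid <= 0:
            hi = mid  # the crossing lies in the lower half
        else:
            lo, g_lo = mid, g_mid  # the crossing lies in the upper half
    raise RuntimeError("bisection did not converge within max_steps")


# Toy curve standing in for CLs(poi): it decays as exp(-poi),
# so it crosses 0.05 at poi = ln(20).
toy_limit = upper_limit_bisect(lambda poi: math.exp(-poi), (0.0, 10.0), 0.05)
```

Note that max_steps is used as a loop count here, which is why its int type matters to the type checker.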
Examples:
    >>> State.from_pytree({"a": [1, 2]}).mapping
    mappingproxy({('a', 0): 1.0, ('a', 1): 2.0})
    mappingproxy({('a', 0): 1, ('a', 1): 2})
    Round-trip with pre-canonicalized keys:
    >>> state = State.from_pytree({"x": 1.0, "y": 2.0})
    >>> flat = state.to_dict()  # {('x',): 1.0, ('y',): 2.0}
    >>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
    True
Fix doc example: canonicalize arg isn’t in State.from_pytree.
The example will raise TypeError because the method doesn’t accept canonicalize. Update the docstring (or add the parameter if intended).
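As a self-contained illustration of both points (the TypeError and the intended round-trip), the sketch below uses a hypothetical `MiniState` class. It is not the real `State`; the pass-through rule for already-canonical input is an assumption based on the commit message ("allow only if all keys are tuples and are not nested").

```python
from types import MappingProxyType
from typing import Any


def _flatten(tree: Any, prefix: tuple = ()) -> dict:
    """Flatten nested dicts and sequences into {path_tuple: leaf}."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return {prefix: tree}
    out: dict = {}
    for key, value in items:
        out.update(_flatten(value, prefix + (key,)))
    return out


class MiniState:
    """Hypothetical stand-in for State, for illustration only."""

    def __init__(self, mapping: dict) -> None:
        self.mapping = MappingProxyType(dict(mapping))

    @classmethod
    def from_pytree(cls, tree: Any) -> "MiniState":
        # Assumed rule: a dict whose keys are all tuples and whose values are
        # all leaves is treated as already canonical and passed through.
        already_flat = (
            isinstance(tree, dict)
            and all(isinstance(k, tuple) for k in tree)
            and not any(isinstance(v, (dict, list, tuple)) for v in tree.values())
        )
        return cls(tree if already_flat else _flatten(tree))

    def to_dict(self) -> dict:
        return dict(self.mapping)


state = MiniState.from_pytree({"x": 1.0, "y": 2.0})
flat = state.to_dict()
round_trip_ok = MiniState.from_pytree(flat).to_dict() == flat

try:
    # from_pytree has no such parameter, so Python rejects the call.
    MiniState.from_pytree(flat, canonicalize=False)
    raised_type_error = False
except TypeError:
    raised_type_error = True
```

Because the unexpected keyword fails at call time, a doctest containing it would fail as soon as the docstring examples are executed.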
📝 Suggested docstring fix
-    >>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
+    >>> State.from_pytree(flat).to_dict() == flat
📝 Committable suggestion
Current code:
    Examples:
        >>> State.from_pytree({"a": [1, 2]}).mapping
        mappingproxy({('a', 0): 1.0, ('a', 1): 2.0})
        mappingproxy({('a', 0): 1, ('a', 1): 2})
        Round-trip with pre-canonicalized keys:
        >>> state = State.from_pytree({"x": 1.0, "y": 2.0})
        >>> flat = state.to_dict()  # {('x',): 1.0, ('y',): 2.0}
        >>> State.from_pytree(flat, canonicalize=False).to_dict() == flat
        True

Suggested replacement:
    Examples:
        >>> State.from_pytree({"a": [1, 2]}).mapping
        mappingproxy({('a', 0): 1, ('a', 1): 2})
        Round-trip with pre-canonicalized keys:
        >>> state = State.from_pytree({"x": 1.0, "y": 2.0})
        >>> flat = state.to_dict()  # {('x',): 1.0, ('y',): 2.0}
        >>> State.from_pytree(flat).to_dict() == flat
        True
🤖 Prompt for AI Agents
In `@src/everwillow/statelib/state.py` around lines 226 - 235, The docstring
example incorrectly calls State.from_pytree(..., canonicalize=False) even though
State.from_pytree does not accept a canonicalize parameter; update the
documentation to remove the canonicalize argument and instead show the intended
round-trip using a pre-canonicalized dict (e.g., show creating flat =
state.to_dict() and then State.from_pytree(flat).to_dict() == flat) or, if the
behavior was intended, add a canonicalize parameter to State.from_pytree and
implement handling; reference the State.from_pytree and to_dict symbols when
making the change.
This should be general. I implemented the main tests to be provided by default; the maths in the asymptotic approximations still needs to be double-checked. Users should be able to construct their own test statistic and test-statistic distribution (whatever approximation is relevant to them) and pass both to our generic calculator, which performs the test. We should try this with a chi2 as an example.
To be merged after #34
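A rough sketch of the chi2 case mentioned in the description: a user-defined test statistic and a user-defined distribution handed to a generic runner. All names here (`Measurement`, `chi2_statistic`, `chi2_one_dof_pvalue`, `run_test`) are hypothetical and do not reflect the calculator interface in this PR. The p-value uses the exact identity for one degree of freedom, P(chi2 > q) = erfc(sqrt(q / 2)).

```python
import math
from collections.abc import Callable
from dataclasses import dataclass


@dataclass(frozen=True)
class Measurement:
    """A single Gaussian measurement with known uncertainty (toy data model)."""

    observed: float
    sigma: float


def chi2_statistic(data: Measurement, poi_test: float) -> float:
    """User-defined test statistic: squared pull of the tested POI value."""
    return ((data.observed - poi_test) / data.sigma) ** 2


def chi2_one_dof_pvalue(q: float) -> float:
    """User-defined distribution: survival function of chi2 with one dof."""
    return math.erfc(math.sqrt(q / 2.0))


def run_test(
    statistic: Callable[[Measurement, float], float],
    distribution: Callable[[float], float],
    data: Measurement,
    poi_test: float,
) -> float:
    """Generic runner: evaluate the statistic, then ask the distribution for a p-value."""
    return distribution(statistic(data, poi_test))


# A 1.96-sigma deviation from the tested value gives a two-sided p-value near 0.05.
p_value = run_test(
    chi2_statistic,
    chi2_one_dof_pvalue,
    Measurement(observed=1.959963984540054, sigma=1.0),
    poi_test=0.0,
)
```

Keeping the statistic and its distribution as separate callables means either can be swapped (for example, a toy-based distribution in place of the analytic one) without touching the runner.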
Summary by CodeRabbit
New Features
Documentation