-
Notifications
You must be signed in to change notification settings - Fork 0
Implement parameter uncertainties and covariance #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughAdded Hessian-based uncertainty utilities to everwillow.inference, migrated statelib APIs from PartitionedMapping/MergeMeta to a State-centric design (merge/split/partition semantics), updated inference fitting to use State, bumped Python/pytest requirements, and added tests and docs reflecting the API changes. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant StateUtils as State utils
participant Uncertainty as uncertainty.py
participant JAX as JAX (hessian, linalg)
Caller->>Uncertainty: uncertainties(nll_fn, params, fixed)
Uncertainty->>StateUtils: partition params into fixed / free
StateUtils-->>Uncertainty: free layout + mapping info
Uncertainty->>Uncertainty: flatten free params -> flat_x
Uncertainty->>Uncertainty: _flat_nll(flat_x) (reconstruct full state for eval)
Uncertainty->>JAX: jax.hessian(_flat_nll) -> Hessian H
JAX-->>Uncertainty: Hessian H
Uncertainty->>JAX: invert H -> covariance
JAX-->>Uncertainty: covariance
Uncertainty->>Uncertainty: compute uncertainties (sqrt(diag)), reassemble State with None for fixed
Uncertainty-->>Caller: return State of uncertainties
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/everwillow/inference/uncertainty.py`:
- Around line 162-170: The returned State uses params.treedefmeta
(full-parameter TreeDefMeta) while its mapping only contains free parameter
keys, causing to_pytree() KeyErrors; change the treedefmeta used when
constructing the State to be built from free_mapping (the result of
sl.partition) instead of params.treedefmeta so the State's treedefmeta matches
the keys in uncertainty_mapping (refer to free_mapping, stderrs, sl.State,
params.treedefmeta, and to_pytree()).
In `@tests/inference/test_uncertainty.py`:
- Around line 72-145: The failures are due to missing local variable type
annotations in the tests (mypy check_untyped_defs); add explicit type
annotations for local vars like nll, params, fixed, and hess in each test (e.g.
annotate nll as Callable[[sl.State], float] or Callable[..., float], params and
fixed as sl.State or Optional[sl.State], and hess as jnp.ndarray) in tests under
TestHessianMatrix (e.g. test_simple_quadratic, test_shape_matches_free_params,
test_fixed_params_excluded, test_all_fixed_returns_empty, test_symmetric,
test_validates_params_type, test_validates_fixed_type) that use
simple_quadratic_nll and hessian_matrix; alternatively, if you prefer not to
annotate tests, add per-line `# type: ignore[var-annotated]` comments or
relax/exclude tests from check_untyped_defs in pyproject.toml to satisfy CI.
🧹 Nitpick comments (1)
src/everwillow/inference/uncertainty.py (1)
80-103: Consider documenting behavior for singular Hessians.The implementation is mathematically correct. However,
jnp.linalg.invwill produce NaN or errors for singular/ill-conditioned Hessians (which can occur at saddle points or when the NLL is flat in some direction).This is acceptable for typical MLE workflows where the Hessian should be positive definite, but documenting this precondition in the docstring would help users diagnose issues.
…ion if user is providing already-canonical pytrees. Allow only if all keys are tuples and are not nested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/everwillow/inference/uncertainty.py`:
- Around line 25-31: Add per-line noqa for pyflakes F722 on the jaxtyping
shape-annotated return/type strings to silence Ruff: update the lines that use
jaxtyping shape strings (e.g., the return annotation on hessian_matrix:
Float[Array, "n_free n_free"] and the other shape-annotated signatures
referenced in this file) by appending " # noqa: F722" to each of those
annotation lines so F722 no longer flags the quoted names as undefined.
In `@src/everwillow/statelib/state.py`:
- Around line 246-261: The code assumes path entries have a .key when
canonicalize is False (in the branch that currently does key =
next(iter(path)).key), which raises AttributeError for sequence keys; change
this to extract the first entry (e.g., entry = next(iter(path))) and explicitly
check its type (isinstance(entry, jtu.DictKey)) before accessing .key; if it
isn’t a DictKey (e.g., is a SequenceKey), raise a ValueError with a clear
message similar to the other checks, otherwise set key = entry.key; reference
the canonicalize=False branch, the path variable, and
jtu.DictKey/jtu.SequenceKey in your change.
| def hessian_matrix( | ||
| nll_fn: tp.Callable[[PyTree[V]], float], | ||
| params: sl.State[V], | ||
| *, | ||
| fixed: sl.State[V | EllipsisType] | None = None, | ||
| ) -> Float[Array, "n_free n_free"]: | ||
| """Compute the Hessian matrix of the NLL at given parameters. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
# First, find the file and check if it exists
find . -type f -name "uncertainty.py" | head -20Repository: MoAly98/everwillow
Length of output: 103
🏁 Script executed:
# Check for Ruff configuration files
find . -type f \( -name "ruff.toml" -o -name ".ruff.toml" -o -name "pyproject.toml" \) | head -20Repository: MoAly98/everwillow
Length of output: 90
🏁 Script executed:
# Read the uncertainty.py file to see the code at the specified lines
if [ -f "src/everwillow/inference/uncertainty.py" ]; then
wc -l src/everwillow/inference/uncertainty.py
echo "=== Lines 20-35 ==="
sed -n '20,35p' src/everwillow/inference/uncertainty.py
echo "=== Lines 55-65 ==="
sed -n '55,65p' src/everwillow/inference/uncertainty.py
echo "=== Lines 75-90 ==="
sed -n '75,90p' src/everwillow/inference/uncertainty.py
fiRepository: MoAly98/everwillow
Length of output: 1449
🏁 Script executed:
# Read the Ruff configuration files
echo "=== ruff.toml ==="
cat ruff.toml
echo ""
echo "=== pyproject.toml (Ruff section) ==="
cat pyproject.toml | grep -A 50 "\[tool.ruff\]"Repository: MoAly98/everwillow
Length of output: 1025
🏁 Script executed:
# Check what Ruff F722 rule is about
# F722 is a Pyflakes rule related to syntax/forward references
# Let's verify if there's any F722-related configuration
rg "F722" . --type tomlRepository: MoAly98/everwillow
Length of output: 209
Add # noqa: F722 comments for jaxtyping shape annotations to prevent Ruff CI failures.
The Ruff configuration selects Pyflakes rules (including F722) but does not ignore it. Lines 30, 61, and 84 contain jaxtyping shape strings like Float[Array, "n_free n_free"] which F722 flags as undefined names. Add per-line # noqa: F722 comments or extend the Ruff ignore list to resolve this:
Per-line ignore approach
-) -> Float[Array, "n_free n_free"]:
+) -> Float[Array, "n_free n_free"]: # noqa: F722
- def _flat_nll(flat_free: Float[Array, ...]) -> Float[Array, ""]:
+ def _flat_nll(flat_free: Float[Array, ...]) -> Float[Array, ""]: # noqa: F722
-) -> Float[Array, "nparams nparams"]:
+) -> Float[Array, "nparams nparams"]: # noqa: F722🧰 Tools
🪛 Ruff (0.14.13)
30-30: Syntax error in forward annotation: Unexpected token at the end of an expression
(F722)
🤖 Prompt for AI Agents
In `@src/everwillow/inference/uncertainty.py` around lines 25 - 31, Add per-line
noqa for pyflakes F722 on the jaxtyping shape-annotated return/type strings to
silence Ruff: update the lines that use jaxtyping shape strings (e.g., the
return annotation on hessian_matrix: Float[Array, "n_free n_free"] and the other
shape-annotated signatures referenced in this file) by appending " # noqa:
F722" to each of those annotation lines so F722 no longer flags the quoted names
as undefined.
| # Get flat array of free values | ||
| free_keys = tuple(free_state.keys()) | ||
| flat_values = jnp.array([free_state[k] for k in free_keys]) | ||
|
|
||
| def _flat_nll(flat_free: Float[Array, ...]) -> Float[Array, ""]: | ||
| """Compute the negative log-likelihood for the flat parameter vector. | ||
| Necessary for jax.hessian. | ||
| Args: | ||
| flat_free (Float[Array, "..."]): Flattened free parameter values. | ||
| Returns: | ||
| Float[Array, ""]: Negative log-likelihood value. | ||
| """ | ||
| free_mapping = {k: flat_free[i] for i, k in enumerate(free_keys)} | ||
| new_free = sl.PartitionedMapping(free_mapping, origin=free_state.origin) | ||
| combined = sl.combine_partitions(fixed_state, new_free) | ||
| full_state = sl.State(combined, treedefmeta=params.treedefmeta) | ||
| return nll_fn(full_state.to_pytree()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
# First, let's read the full file to understand the context
wc -l src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 180
🏁 Script executed:
# Read the uncertainty.py file with line numbers to see full context
cat -n src/everwillow/inference/uncertainty.py | head -100Repository: MoAly98/everwillow
Length of output: 180
🏁 Script executed:
# Search for hessian_matrix function and its callers
rg -A 50 "def hessian_matrix" src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 1978
🏁 Script executed:
# Check how free_state is created/used in the file
rg "free_state" src/everwillow/inference/uncertainty.py -B 2 -A 2Repository: MoAly98/everwillow
Length of output: 716
🏁 Script executed:
# Look for any existing leaf validation
rg "scalar|leaf|ndim|flatten" src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 346
🏁 Script executed:
# Check if jax.flatten_util is already imported or used elsewhere in the codebase
rg "flatten_util|ravel_pytree" --type pyRepository: MoAly98/everwillow
Length of output: 307
🏁 Script executed:
# Check for Ruff F722 warnings (unused type annotation) in the file
rg "F722" src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
# Read the entire uncertainty.py file to see downstream usage of hessian_matrix
cat -n src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 7510
🏁 Script executed:
# Check covariance_matrix and correlation_matrix functions that use hessian_matrix
rg -A 20 "def covariance_matrix|def correlation_matrix|def uncertainties" src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 2302
🏁 Script executed:
# Look for how hessian_matrix output is used
rg "hessian_matrix" src/everwillow/inference/uncertainty.py -A 3 -B 1Repository: MoAly98/everwillow
Length of output: 325
🏁 Script executed:
# Check tests to see what inputs are expected/tested
find . -name "*.py" -path "*/test*" | xargs grep -l "hessian_matrix\|covariance_matrix" 2>/dev/null | head -5Repository: MoAly98/everwillow
Length of output: 99
🏁 Script executed:
# Read the test file to understand expected behavior and inputs
cat -n tests/inference/test_uncertainty.pyRepository: MoAly98/everwillow
Length of output: 15752
🏁 Script executed:
# Check what types of parameters are used in examples
rg "State\(" examples/ -A 2 -B 2 | head -40Repository: MoAly98/everwillow
Length of output: 44
🏁 Script executed:
# Look for any array-valued parameters in the codebase
rg "sl.State|free_state" tests/inference/test_uncertainty.py -A 3 -B 1Repository: MoAly98/everwillow
Length of output: 4606
Hessian computation requires scalar-valued parameters.
Line 59 stacks leaves into a single array without validating they are scalars. If any leaf is an array (e.g., shape (2,)), line 71's reconstruction via flat_free[i] would extract scalars instead of preserving the original array shape, causing incorrect gradient computation. Downstream functions (covariance_matrix, correlation_matrix, uncertainties) all assume the Hessian is 2D via jnp.diag() and jnp.linalg.inv(), which will fail or produce incorrect results for higher-rank tensors.
Add fail-fast validation or replace with jax.flatten_util.ravel_pytree() to properly support array-valued parameters.
🛡️ Fail-fast validation (minimal fix)
# Get flat array of free values
free_keys = tuple(free_state.keys())
+ if any(jnp.ndim(free_state[k]) != 0 for k in free_keys):
+ raise ValueError(
+ "hessian_matrix currently supports scalar leaves only; "
+ "flatten array parameters first."
+ )
flat_values = jnp.array([free_state[k] for k in free_keys])🧰 Tools
🪛 Ruff (0.14.13)
61-61: Syntax error in forward annotation: Expected an expression
(F722)
| K: tp.TypeAlias = str | tuple[str, ...] | ||
| V = tp.TypeVar("V", bound=ArrayLike) | ||
|
|
||
| def _flatten_iterables(x: tp.Any) -> tp.Iterator[tp.Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pfackeldey cloest i can come up with to general handling 🗡️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/everwillow/inference/uncertainty.py`:
- Around line 1-176: CI flagged formatting changes from ruff-format; run the
formatter and commit the updated file. Locally run `ruff format
src/everwillow/inference/uncertainty.py`, review the edits (affecting functions
like hessian_matrix, covariance_matrix, correlation_matrix, uncertainties),
stage the reformatted file, and commit/push to resolve the pre-commit/CI
failure.
♻️ Duplicate comments (2)
src/everwillow/inference/uncertainty.py (2)
30-30: Add# noqa: F722for jaxtyping shape annotation.The jaxtyping string
"n_free n_free"triggers Ruff F722. This also applies to lines 61 and 85.Suggested fix
-) -> Float[Array, "n_free n_free"]: +) -> Float[Array, "n_free n_free"]: # noqa: F722
57-75: Hessian computation assumes scalar-valued parameter leaves.Line 59 stacks leaves into a flat array assuming each leaf is a scalar. If any leaf is an array (e.g., shape
(2,)), line 71'sflat_free[i]extracts a scalar instead of preserving the original shape, causing incorrect gradient computation.Consider adding validation or using
jax.flatten_util.ravel_pytree()for proper array support:Minimal fail-fast validation
# Get flat array of free values free_keys = tuple(free_state.keys()) + if any(jnp.ndim(free_state[k]) != 0 for k in free_keys): + raise ValueError( + "hessian_matrix currently supports scalar leaves only; " + "flatten array parameters first." + ) flat_values = jnp.array([free_state[k] for k in free_keys])
🧹 Nitpick comments (2)
src/everwillow/inference/uncertainty.py (2)
101-103: Consider handling singular or ill-conditioned Hessians.
jnp.linalg.inv(hess)will fail or produce numerically unstable results if the Hessian is singular or ill-conditioned (e.g., at saddle points or with redundant parameters). Consider usingjnp.linalg.pinv()or adding a small regularization term for robustness.
162-165: Minor: Redundant partitioning.The
partitioncall here duplicates work already done internally bycovariance_matrix→hessian_matrix. For a small optimization,hessian_matrixcould return the free keys alongside the Hessian, or this could be refactored to share the partitioning logic. This is a minor efficiency concern and can be deferred.
…d of PartionedMapping objects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/everwillow/statelib/state.py (1)
368-386: Partition docs still describe the old mapping/origin behavior.The Args block and example still imply mapping identity and show
right.to_pytree()withoutNoneplaceholders, but partitions now containNonefor excluded keys. Update the docstring and example to align with the new State-based semantics.✏️ Suggested doc fix
- Args: - mapping: Mapping obtained from a :class:`State`. The identity of this - mapping is stored to ensure only compatible partitions are merged. + Args: + state: State instance to partition. @@ - >>> right.to_pytree() - {('b',): 2} + >>> dict(right.notnone) + {('b',): 2}
🤖 Fix all issues with AI agents
In `@src/everwillow/inference/fitting.py`:
- Around line 34-41: The tuple unpacking in _reconstruct_full_state currently
binds treedefmeta but never uses it; update the unpacking of Args inside
_reconstruct_full_state to use an underscore for that element (e.g.,
(fixed_state, _treedefmeta, bounds) or simply (_,) style) or remove treedefmeta
from the Args definition if it is genuinely unnecessary; modify the unpacking in
the _reconstruct_full_state function (which currently binds free_state,
fixed_state, treedefmeta, bounds) so treedefmeta is explicitly marked unused to
satisfy RUF059 and avoid linter warnings.
In `@src/everwillow/inference/uncertainty.py`:
- Around line 170-173: The code creates an intermediate variable uncs_combined
only to return it; replace the assignment and return with a single inline
return: directly return sl.combine_partitions(fixed_uncertainty,
free_uncertainty) where fixed_uncertainty is produced by jax.tree.map(lambda _:
None, fixed_state); remove the now-unused uncs_combined variable to satisfy
RET504 and keep the function concise.
- Around line 161-169: The current reconstruction of free_uncertainty uses
tree_unflatten(free_state) which fails when fixed params exist because
free_state contains None leaves; instead, mirror the pattern from
hessian_matrix: build a flat list only for free keys (use
free_state.notnone.keys() or equivalent), reconstruct a pytree with sl.update()
by mapping stderrs back into those free keys, and then merge with fixed_state to
produce the final uncertainty pytree; update the code around fixed_state,
free_state, stderrs to use that selective-key unflattening and sl.update() so
leaf counts match.
♻️ Duplicate comments (2)
src/everwillow/inference/uncertainty.py (2)
57-74: Hessian assumes scalar leaves; array-valued params break shapes/inversion.
flat_valuesis built from leaves and downstream code assumes a 2D Hessian. If any leaf is array-valued, the Hessian rank explodes andinv/diagbecome invalid. Add a scalar-only guard or switch tojax.flatten_util.ravel_pytree.✅ Minimal scalar-only guard
# Get flat array of free values free_keys = tuple(free_state.notnone.keys()) + if any(jnp.ndim(free_state[k]) != 0 for k in free_keys): + raise ValueError( + "hessian_matrix currently supports scalar leaves only; " + "flatten array parameters before calling." + ) flat_values = jnp.array([free_state[k] for k in free_keys])
25-31: Ruff F722 on jaxtyping shape strings.The shape-annotated return types trigger F722. Add per-line
# noqa: F722(or update Ruff config).✅ Per-line ignores
-) -> Float[Array, "n_free n_free"]: +) -> Float[Array, "n_free n_free"]: # noqa: F722 - def _flat_nll(flat_free: Float[Array, ...]) -> Float[Array, ""]: + def _flat_nll(flat_free: Float[Array, ...]) -> Float[Array, ""]: # noqa: F722 -) -> Float[Array, "nparams nparams"]: +) -> Float[Array, "nparams nparams"]: # noqa: F722Also applies to: 61-61, 84-85
| # Get free_state with same structure/ordering as used for Hessian | ||
| fixed_state, free_state = sl.partition( | ||
| params, predicate=lambda key, _: key in fixed | ||
| ) | ||
|
|
||
| # Unflatten stderrs back into the same pytree structure as free_state | ||
| _, treedef = jax.tree_util.tree_flatten(free_state) | ||
| free_uncertainty = jax.tree_util.tree_unflatten(treedef, stderrs) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
cat -n src/everwillow/inference/uncertainty.py | head -200Repository: MoAly98/everwillow
Length of output: 7298
🏁 Script executed:
find . -name "statelib*" -o -name "*statelib*" | head -20Repository: MoAly98/everwillow
Length of output: 315
🏁 Script executed:
ls -la src/everwillow/statelib/Repository: MoAly98/everwillow
Length of output: 477
🏁 Script executed:
cat -n src/everwillow/statelib/state.pyRepository: MoAly98/everwillow
Length of output: 18068
🏁 Script executed:
find . -path ./node_modules -prune -o -type f -name "*.py" -print | xargs grep -l "uncertainties" | head -20Repository: MoAly98/everwillow
Length of output: 180
🏁 Script executed:
cat -n tests/inference/test_uncertainty.pyRepository: MoAly98/everwillow
Length of output: 15752
🏁 Script executed:
cd . && python -m pytest tests/inference/test_uncertainty.py::TestUncertainties::test_fixed_params_none_uncertainty -xvs 2>&1 | head -100Repository: MoAly98/everwillow
Length of output: 148
🏁 Script executed:
python3 << 'EOF'
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
# Simulate what happens with free_state containing None values
free_state_with_none = {
"x": 0.5, # free param
"y": None # fixed param (None)
}
# Test tree_flatten with None values
leaves, treedef = jtu.tree_flatten(free_state_with_none)
print(f"Flattened leaves: {leaves}")
print(f"Number of leaves: {len(leaves)}")
# Now try to unflatten with fewer leaves (simulating stderrs with only 1 value)
stderrs = jnp.array([0.1]) # Only 1 value for the free param
print(f"Stderrs shape: {stderrs.shape}")
print(f"Stderrs length: {len(stderrs)}")
try:
result = jtu.tree_unflatten(treedef, stderrs)
print(f"Unflatten succeeded: {result}")
except Exception as e:
print(f"Unflatten FAILED with error: {type(e).__name__}: {e}")
EOFRepository: MoAly98/everwillow
Length of output: 177
🏁 Script executed:
python3 << 'EOF'
# Let's trace through the logic manually without JAX
# When partition returns free_state, it has the same treedefmeta as original
# but with None values for fixed keys
free_state_data = {
("x",): 0.5, # free param
("y",): None # fixed param set to None
}
# tree_flatten(free_state) will flatten the State's internal mapping
# which includes all keys including the None value
print("Free state data (includes all keys):")
print(free_state_data)
print(f"Number of leaves in free_state: {len(free_state_data)}")
# But stderrs only has values for free parameters
stderrs = [0.1] # Only 1 value for x
print(f"\nStderrs length: {len(stderrs)}")
# tree_unflatten expects 2 leaves but gets only 1 -> ERROR
print(f"\nMismatch: tree_unflatten expects {len(free_state_data)} leaves but stderrs has {len(stderrs)}")
EOFRepository: MoAly98/everwillow
Length of output: 240
🏁 Script executed:
sed -n '57,76p' src/everwillow/inference/uncertainty.pyRepository: MoAly98/everwillow
Length of output: 870
Refactor uncertain reconstruction to use only free parameters, matching the pattern in hessian_matrix.
When fixed parameters are present, tree_flatten(free_state) includes None leaves for those keys, but stderrs only contains values for free parameters. The tree_unflatten call will fail due to leaf count mismatch. The hessian_matrix function (lines 58, 71–72) already shows the correct pattern: use free_state.notnone.keys() to get only free keys, then rebuild with sl.update().
🔧 Corrected reconstruction
- # Unflatten stderrs back into the same pytree structure as free_state
- _, treedef = jax.tree_util.tree_flatten(free_state)
- free_uncertainty = jax.tree_util.tree_unflatten(treedef, stderrs)
+ # Rebuild free uncertainties using only free keys
+ free_keys = tuple(free_state.notnone.keys())
+ free_updates = {k: stderrs[i] for i, k in enumerate(free_keys)}
+ free_uncertainty = sl.update(free_state, updates=free_updates)🤖 Prompt for AI Agents
In `@src/everwillow/inference/uncertainty.py` around lines 161 - 169, The current
reconstruction of free_uncertainty uses tree_unflatten(free_state) which fails
when fixed params exist because free_state contains None leaves; instead, mirror
the pattern from hessian_matrix: build a flat list only for free keys (use
free_state.notnone.keys() or equivalent), reconstruct a pytree with sl.update()
by mapping stderrs back into those free keys, and then merge with fixed_state to
produce the final uncertainty pytree; update the code around fixed_state,
free_state, stderrs to use that selective-key unflattening and sl.update() so
leaf counts match.
…tead of FrozenChainMap, based on jax FlatTree pack and unpack methods
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
docs/statelib_overview.md (6)
3-3: Inconsistent terminology:FlatStateshould beState.The opening paragraph still references
FlatStateas the core primitive, but the API has been renamed toState. This should be updated for consistency with the rest of the documentation and the actual API.📝 Suggested fix
-Small collection of helpers for working with flattened parameter states in the Everwillow codebase. Import the API via `import everwillow.statelib as sl`. The core primitive is `FlatState`, an immutable mapping of canonical key tuples to leaf values that carries enough metadata to round‑trip back to the original pytree. +Small collection of helpers for working with flattened parameter states in the Everwillow codebase. Import the API via `import everwillow.statelib as sl`. The core primitive is `State`, an immutable mapping of canonical key tuples to leaf values that carries enough metadata to round‑trip back to the original pytree.
11-11: Inconsistent usage:sl.FlatState.from_pytreeshould besl.State.from_pytree.This example uses the old
FlatStatename which has been renamed toState.📝 Suggested fix
-state = sl.FlatState.from_pytree(tree) +state = sl.State.from_pytree(tree)
26-26: Inconsistent terminology:FlatStateshould beState.The feature description still references
FlatState.from_pytreeinstead of the renamedState.from_pytree.
39-42: Inconsistent function name:partition_stateshould bepartition.The example uses
sl.partition_statebut the function has been renamed tosl.partition.📝 Suggested fix
-flags, rest = sl.partition_state( +flags, rest = sl.partition( state, predicate=lambda key, _value: key[0] == "flags", )
95-97: Inconsistent usage:sl.FlatState.from_pytreeshould besl.State.from_pytree.The end-to-end example still uses
FlatStateinstead ofState.📝 Suggested fix
-state_a = sl.FlatState.from_pytree(tree_a) -state_b = sl.FlatState.from_pytree(tree_b) -state_c = sl.FlatState.from_pytree(tree_c) +state_a = sl.State.from_pytree(tree_a) +state_b = sl.State.from_pytree(tree_b) +state_c = sl.State.from_pytree(tree_c)
129-132: Inconsistent function name:partition_stateshould bepartition.This example uses the old
sl.partition_statefunction name.📝 Suggested fix
-first_partition, second_partition = sl.partition_state( +first_partition, second_partition = sl.partition( merged, predicate=lambda key, _: "a" in key, )
🤖 Fix all issues with AI agents
In `@src/everwillow/statelib/state.py`:
- Around line 297-305: Add explicit type annotations for the local variables
all_keys, all_vals, child_treedefs, mapping, and compound_treedef inside the
code that builds the compound state (the block that constructs all_keys,
all_vals, child_treedefs, compound_treedef, mapping and returns State(...));
also change the dict(zip(...)) call to use zip(all_keys, all_vals, strict=True)
so the pairing is strict. Ensure the annotations use appropriate types (e.g.,
List[str] or Sequence[...] for all_keys, List[Any] for all_vals,
List[jtu.PyTreeDef] for child_treedefs, Dict[str, Any] for mapping, and the
treedef type for compound_treedef) and keep the return line using State(mapping,
treedefmeta=TreeDefMeta(compound_treedef, tuple(all_keys))).
- Around line 349-353: The docstring example for partition is wrong because
partition returns States that keep excluded keys as None; update the example to
call right.notnone before to_pytree (e.g., use right.notnone.to_pytree()) or
adjust the expected output to include {('a',): None, ('b',): 2}; reference the
State.from_pytree and partition example lines and show the corrected output that
matches right.notnone.to_pytree().
🧹 Nitpick comments (1)
src/everwillow/statelib/state.py (1)
393-393: Consider breaking the longtree_mapcall for readability.This line exceeds typical line length limits and could be more readable if split.
♻️ Suggested formatting
- return jtu.tree_map(lambda x1, x2: x1 if x1 is not None else x2, left, right, is_leaf=lambda x: x is None) + return jtu.tree_map( + lambda x1, x2: x1 if x1 is not None else x2, + left, + right, + is_leaf=lambda x: x is None, + )
src/everwillow/statelib/state.py
Outdated
| all_keys, all_vals, child_treedefs = [], [], [] | ||
| for s in states: | ||
| all_keys.extend(s.treedefmeta.keys) | ||
| all_vals.extend(s[k] for k in s.treedefmeta.keys) | ||
| child_treedefs.append(s.treedefmeta.treedef) | ||
|
|
||
| compound_treedef = jtu.treedef_tuple(child_treedefs) | ||
| mapping = dict(zip(all_keys, all_vals)) | ||
| return State(mapping, treedefmeta=TreeDefMeta(compound_treedef, tuple(all_keys))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type annotations and strict=True to address CI failures.
The pipeline reports missing type annotations for local variables, and the static analysis flags the zip() call without an explicit strict= parameter.
📝 Suggested fix
- all_keys, all_vals, child_treedefs = [], [], []
+ all_keys: list[K] = []
+ all_vals: list[V] = []
+ child_treedefs: list[jtu.PyTreeDef] = []
for s in states:
all_keys.extend(s.treedefmeta.keys)
all_vals.extend(s[k] for k in s.treedefmeta.keys)
child_treedefs.append(s.treedefmeta.treedef)
compound_treedef = jtu.treedef_tuple(child_treedefs)
- mapping = dict(zip(all_keys, all_vals))
+ mapping = dict(zip(all_keys, all_vals, strict=True))
return State(mapping, treedefmeta=TreeDefMeta(compound_treedef, tuple(all_keys)))🧰 Tools
🪛 GitHub Actions: CI
[error] 301-305: Mypy/type hints: missing explicit type annotations for local variables introduced by patch (e.g., all_keys, all_vals). Ensure types are annotated to satisfy type checker.
🪛 Ruff (0.14.13)
304-304: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
In `@src/everwillow/statelib/state.py` around lines 297 - 305, Add explicit type
annotations for the local variables all_keys, all_vals, child_treedefs, mapping,
and compound_treedef inside the code that builds the compound state (the block
that constructs all_keys, all_vals, child_treedefs, compound_treedef, mapping
and returns State(...)); also change the dict(zip(...)) call to use
zip(all_keys, all_vals, strict=True) so the pairing is strict. Ensure the
annotations use appropriate types (e.g., List[str] or Sequence[...] for
all_keys, List[Any] for all_vals, List[jtu.PyTreeDef] for child_treedefs,
Dict[str, Any] for mapping, and the treedef type for compound_treedef) and keep
the return line using State(mapping, treedefmeta=TreeDefMeta(compound_treedef,
tuple(all_keys))).
| >>> state = State.from_pytree({"a": 1, "b": 2}) | ||
| >>> left, right = partition(state, lambda key, _: key == ("a",)) | ||
| >>> dict(right.mapping) | ||
| >>> right.to_pytree() | ||
| {('b',): 2} | ||
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, let's find and examine the partition function definition
rg -A 30 "def partition" src/everwillow/statelib/state.py | head -50Repository: MoAly98/everwillow
Length of output: 1143
🏁 Script executed:
#!/bin/bash
# Check the to_pytree method implementation
rg -A 20 "def to_pytree" src/everwillow/statelib/state.py | head -40Repository: MoAly98/everwillow
Length of output: 640
🏁 Script executed:
#!/bin/bash
# Also check TreeDefMeta.to_pytree as mentioned in the review
fd "meta.py" src/ | xargs rg -l "class TreeDefMeta"Repository: MoAly98/everwillow
Length of output: 93
🏁 Script executed:
#!/bin/bash
# Look for the notnone property/method mentioned in the suggestion
rg "notnone" src/everwillow/statelib/state.pyRepository: MoAly98/everwillow
Length of output: 133
🏁 Script executed:
#!/bin/bash
# Get TreeDefMeta.to_pytree implementation
rg -A 30 "class TreeDefMeta" src/everwillow/statelib/meta.py | head -80Repository: MoAly98/everwillow
Length of output: 1073
🏁 Script executed:
#!/bin/bash
# Search more specifically for to_pytree in TreeDefMeta
rg -B 2 -A 25 "def to_pytree" src/everwillow/statelib/meta.pyRepository: MoAly98/everwillow
Length of output: 1049
🏁 Script executed:
#!/bin/bash
# Let's also check the notnone property implementation
rg -B 2 -A 10 "def notnone" src/everwillow/statelib/state.pyRepository: MoAly98/everwillow
Length of output: 506
Docstring example is incorrect.
The partition function returns States containing all original keys with excluded keys set to None. When right.to_pytree() is called, the resulting pytree includes all keys: {('a',): None, ('b',): 2}, not {('b',): 2} as shown in the example. The suggested fix to use right.notnone is correct—it filters out None values and produces the expected output.
📝 Suggested fix
Examples:
>>> state = State.from_pytree({"a": 1, "b": 2})
>>> left, right = partition(state, lambda key, _: key == ("a",))
- >>> right.to_pytree()
- {('b',): 2}
+ >>> right.notnone
+ {('b',): 2}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| >>> state = State.from_pytree({"a": 1, "b": 2}) | |
| >>> left, right = partition(state, lambda key, _: key == ("a",)) | |
| >>> dict(right.mapping) | |
| >>> right.to_pytree() | |
| {('b',): 2} | |
| """ | |
| >>> state = State.from_pytree({"a": 1, "b": 2}) | |
| >>> left, right = partition(state, lambda key, _: key == ("a",)) | |
| >>> right.notnone | |
| {('b',): 2} | |
| """ |
🤖 Prompt for AI Agents
In `@src/everwillow/statelib/state.py` around lines 349 - 353, The docstring
example for partition is wrong because partition returns States that keep
excluded keys as None; update the example to call right.notnone before to_pytree
(e.g., use right.notnone.to_pytree()) or adjust the expected output to include
{('a',): None, ('b',): 2}; reference the State.from_pytree and partition example
lines and show the corrected output that matches right.notnone.to_pytree().
|
@pfackeldey I simplified the |
| "RUF003", # Ambiguous Unicode characters in comments (Greek letters are standard in physics) | ||
| "F722", # jaxtyping: `Float32[Array, "b c"]` (https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error) | ||
| "EM101", # flake8-errmsg: Error message should not end with a period | ||
| "UP040", # PEP 695 type alias syntax (https://peps.python.org/pep-0695/) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pfackeldey what do you think about this btw -- it seems the standard was enforced in 3.13 but it is incredibly frustrating to use
Summary by CodeRabbit
New Features
Tests
Refactor
Documentation
Chores
✏️ Tip: You can customize this high-level summary in your review settings.