Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* Uses prettier `tqdm` output that is now aware of Jupyter notebooks.
* `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`.
* Allow `zero_inflated` to accept 0 or 1 as probabilities.

## v0.29 - latest release

Expand Down
4 changes: 3 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co

## Workflow
- Always run `make format` before committing to ensure code passes formatting checks
- Update CHANGES.md when adding new features or fixing bugs
- Always update `CHANGES.md` when making code changes (bug fixes, new features, etc.)
- Add changelog entries under the development version section at the top (marked with `- development version`)
- Follow the existing format: `* Description of change`
9 changes: 8 additions & 1 deletion squigglepy/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,8 +1796,15 @@ def zero_inflated(p_zero, dist):
- 0
- <Distribution> norm(mean=1.5, sd=0.3)
"""
if p_zero > 1 or p_zero < 0 or not isinstance(p_zero, float):
if isinstance(p_zero, bool) or not isinstance(p_zero, (float, int)):
raise ValueError("`p_zero` must be a float or int")
if not 0 <= p_zero <= 1:
raise ValueError("`p_zero` must be between 0 and 1")
# Handle edge cases: p_zero=1 means always 0, p_zero=0 means always sample from dist
if p_zero == 1:
return MixtureDistribution(dists=[0], weights=[1])
elif p_zero == 0:
return MixtureDistribution(dists=[dist], weights=[1])
return MixtureDistribution(dists=[0, dist], weights=p_zero)


Expand Down
2 changes: 1 addition & 1 deletion squigglepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _process_weights_values(weights=None, relative_weights=None, values=None, dr
weights = relative_weights
relative = True

if isinstance(weights, float):
if isinstance(weights, (float, int)) and not isinstance(weights, bool):
weights = [weights]
elif isinstance(weights, np.ndarray):
weights = list(weights)
Expand Down