diff --git a/CHANGES.md b/CHANGES.md index 2c87ca5..0e176a8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,7 @@ * Uses prettier `tqdm` output that is now aware of Jupyter notebooks. * `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`. +* Allow `zero_inflated` to accept 0 or 1 as probabilities. ## v0.29 - latest release diff --git a/CLAUDE.md b/CLAUDE.md index fe9fe1b..eb7d027 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -21,4 +21,6 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Workflow - Always run `make format` before committing to ensure code passes formatting checks -- Update CHANGES.md when adding new features or fixing bugs \ No newline at end of file +- Always update `CHANGES.md` when making code changes (bug fixes, new features, etc.) +- Add changelog entries under the development version section at the top (marked with `- development version`) +- Follow the existing format: `* Description of change` diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index ec33c5d..3483767 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1796,8 +1796,15 @@ def zero_inflated(p_zero, dist): - 0 - norm(mean=1.5, sd=0.3) """ - if p_zero > 1 or p_zero < 0 or not isinstance(p_zero, float): + if isinstance(p_zero, bool) or not isinstance(p_zero, (float, int)): + raise ValueError("`p_zero` must be a float or int") + if not 0 <= p_zero <= 1: raise ValueError("`p_zero` must be between 0 and 1") + # Handle edge cases: p_zero=1 means always 0, p_zero=0 means always sample from dist + if p_zero == 1: + return MixtureDistribution(dists=[0], weights=[1]) + elif p_zero == 0: + return MixtureDistribution(dists=[dist], weights=[1]) return MixtureDistribution(dists=[0, dist], weights=p_zero) diff --git a/squigglepy/utils.py b/squigglepy/utils.py index 71ca07b..74b9cec 100644 --- a/squigglepy/utils.py +++ b/squigglepy/utils.py @@ -32,7 +32,7 @@ def _process_weights_values(weights=None, relative_weights=None, values=None, dr weights = relative_weights relative = True - if isinstance(weights, float): + if isinstance(weights, (float, int)) and not isinstance(weights, bool): weights = [weights] elif isinstance(weights, np.ndarray): weights = list(weights)