From 576d6bc4c4439f87384e56acd136b4c83c012f3b Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Fri, 1 Dec 2023 17:23:24 -0800 Subject: [PATCH 01/10] analytic: basic simplification for addition and multiplication --- squigglepy/distributions.py | 102 ++++++++++++++++++++++++++++++++++++ tests/test_simplify.py | 55 +++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 tests/test_simplify.py diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 413e749..158fd79 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -12,6 +12,7 @@ from collections.abc import Iterable from abc import ABC, abstractmethod +from functools import reduce class BaseDistribution(ABC): @@ -172,6 +173,12 @@ def __rpow__(self, dist): def __hash__(self): return hash(repr(self)) + def _build_flat_tree(self): + return FlatTree.leaf(self) + + def simplify(self): + return self + # Distribution are either discrete, continuous, or composite @@ -242,6 +249,16 @@ def __str__(self): raise ValueError return out + def _build_flat_tree(self): + left_tree = self.left._build_flat_tree() + right_tree = None + if self.right is not None: + right_tree = self.right._build_flat_tree() + return FlatTree.branch(self.fn, left_tree, right_tree) + + def simplify(self): + return self._build_flat_tree().simplify() + def _get_fname(f, name): if name is None: @@ -1699,3 +1716,88 @@ def geometric(p): geometric(0.1) """ return GeometricDistribution(p=p) + + +class FlatTree: + COMMUTABLE_OPERATIONS = set([operator.add, operator.mul]) + + def __init__(self, dist=None, fn=None, dists=None, children=None, is_unary=False): + self.dist = dist + self.fn = fn + self.dists = dists + self.children = children + self.is_unary = is_unary + if dist is not None: + self.is_leaf = True + elif fn is not None and dists is not None: + self.is_leaf = False + else: + raise ValueError("Missing arguments to FlatTree constructor") + + @classmethod + def leaf(cls, dist): + return FlatTree(dist=dist) + + @classmethod + def branch(cls, fn, left_tree, right_tree): + # make a list of possibly-joinable distributions, plus a list of + # children as trees who could not be simplified at this level + dists = [] + children = [] + is_unary = right_tree is None + if is_unary and right_tree is not None: + raise ValueError(f"Multiple arguments provided for unary operator {fn}") + if fn == operator.neg and left_tree.is_leaf: + dist = left_tree.dist + if isinstance(dist, NormalDistribution): + return cls.leaf(NormalDistribution(mean=-dist.mean, sd=dist.sd)) + if fn == operator.sub: + return cls.branch( + operator.add, + left_tree, + FlatTree.branch(operator.neg, right_tree, None) + ) + + if left_tree.is_leaf: + dists.append(left_tree.dist) + elif left_tree.fn == fn and fn in cls.COMMUTABLE_OPERATIONS: + dists.extend(left_tree.dists) + else: + children.append(left_tree) + if right_tree is not None: + if right_tree.is_leaf: + dists.append(right_tree.dist) + elif right_tree.fn == fn and fn in cls.COMMUTABLE_OPERATIONS: + dists.extend(right_tree.dists) + else: + children.append(right_tree) + + dists.sort(key=lambda d: type(d).__name__) + + return cls(fn=fn, dists=dists, children=children, is_unary=is_unary) + + def _join_dists(self, name, join_fn): + dist_indexes = [i for i in range(len(self.dists)) if type(self.dists[i]).__name__ == name] + if len(dist_indexes) == 0: + return None + first_index = dist_indexes[0] + for i in dist_indexes[1:]: + self.dists[first_index] = join_fn(self.dists[first_index], self.dists[i]) + self.dists[i] = None + self.dists = [d for d in self.dists if d is not None] + + def simplify(self): + if self.is_leaf: + return self.dist + + simplified_children = [child.simplify() for child in self.children] + if self.fn == operator.add: + self._join_dists("NormalDistribution", lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2))) + self._join_dists("float", lambda x, y: x + y) + self._join_dists("int", lambda x, y: x + y) + elif self.fn == operator.mul: + self._join_dists("LognormalDistribution", lambda x, y: LognormalDistribution(norm_mean=x.norm_mean + y.norm_mean, norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2))) + self._join_dists("float", lambda x, y: x * y) + self._join_dists("int", lambda x, y: x * y) + + return reduce(lambda acc, x: ComplexDistribution(acc, x), simplified_children + self.dists) diff --git a/tests/test_simplify.py b/tests/test_simplify.py new file mode 100644 index 0000000..e5a5f68 --- /dev/null +++ b/tests/test_simplify.py @@ -0,0 +1,55 @@ +from hypothesis import assume, example, given +import hypothesis.strategies as st +import numpy as np +import pytest +from pytest import approx + +from ..squigglepy.distributions import * + +def test_simplify_normal_add(): + x = NormalDistribution(mean=1, sd=1) + y = NormalDistribution(mean=2, sd=2) + sum2 = x + y + simplified2 = sum2.simplify() + assert isinstance(simplified2, NormalDistribution) + assert simplified2.mean == 3 + assert simplified2.sd == approx(np.sqrt(5)) + +def test_simplify_normal_add3(): + x = NormalDistribution(mean=1, sd=1) + y = NormalDistribution(mean=2, sd=2) + z = NormalDistribution(mean=-3, sd=2) + sum3_left = (x + y) + z + sum3_right = x + (y + z) + simplified3_left = sum3_left.simplify() + simplified3_right = sum3_right.simplify() + assert isinstance(simplified3_left, NormalDistribution) + assert simplified3_left.mean == 0 + assert simplified3_left.sd == approx(np.sqrt(1 + 4 + 4)) + assert isinstance(simplified3_right, NormalDistribution) + assert simplified3_right.mean == 0 + assert simplified3_right.sd == approx(np.sqrt(1 + 4 + 4)) + +def test_simplify_lognorm_mul3(): + x = LognormalDistribution(norm_mean=1, norm_sd=1) + y = LognormalDistribution(norm_mean=2, norm_sd=2) + z = LognormalDistribution(norm_mean=-2, norm_sd=3) + product3_left = (x * y) * z + product3_right = x * (y * z) + simplified3_left = product3_left.simplify() + simplified3_right = product3_right.simplify() + assert isinstance(simplified3_left, LognormalDistribution) + assert simplified3_left.norm_mean == 1 + assert simplified3_left.norm_sd == approx(np.sqrt(1 + 4 + 9)) + assert isinstance(simplified3_right, LognormalDistribution) + assert simplified3_right.norm_mean == 1 + assert simplified3_right.norm_sd == approx(np.sqrt(1 + 4 + 9)) + +def test_simplify_sub(): + x = NormalDistribution(mean=1, sd=1) + y = NormalDistribution(mean=2, sd=2) + difference = x - y + simplified = difference.simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == -1 + assert simplified.sd == approx(np.sqrt(5)) From a42abc4ce4816bb922f499846c3bb7fd52499113 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sat, 2 Dec 2023 23:09:51 -0800 Subject: [PATCH 02/10] analytic: implement scaling and refactor to be more extensible --- squigglepy/distributions.py | 104 ++++++++++++++++++++---------------- tests/test_simplify.py | 24 +++++++-- 2 files changed, 80 insertions(+), 48 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 158fd79..24a538d 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -173,9 +173,6 @@ def __rpow__(self, dist): def __hash__(self): return hash(repr(self)) - def _build_flat_tree(self): - return FlatTree.leaf(self) - def simplify(self): return self @@ -249,15 +246,8 @@ def __str__(self): raise ValueError return out - def _build_flat_tree(self): - left_tree = self.left._build_flat_tree() - right_tree = None - if self.right is not None: - right_tree = self.right._build_flat_tree() - return FlatTree.branch(self.fn, left_tree, right_tree) - def simplify(self): - return self._build_flat_tree().simplify() + return FlatTree.build(self).simplify() def _get_fname(f, name): @@ -1735,56 +1725,81 @@ def __init__(self, dist=None, fn=None, dists=None, children=None, is_unary=False raise ValueError("Missing arguments to FlatTree constructor") @classmethod - def leaf(cls, dist): - return FlatTree(dist=dist) + def build(cls, dist): + if isinstance(dist, int) or isinstance(dist, float): + return FlatTree(dist=float(dist)) + if not isinstance(dist, BaseDistribution): + import ipdb; ipdb.set_trace() + raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}") + if not isinstance(dist, ComplexDistribution): + return FlatTree(dist=dist) - @classmethod - def branch(cls, fn, left_tree, right_tree): # make a list of possibly-joinable distributions, plus a list of # children as trees who could not be simplified at this level dists = [] children = [] - is_unary = right_tree is None - if is_unary and right_tree is not None: - raise ValueError(f"Multiple arguments provided for unary operator {fn}") - if fn == operator.neg and left_tree.is_leaf: - dist = left_tree.dist - if isinstance(dist, NormalDistribution): - return cls.leaf(NormalDistribution(mean=-dist.mean, sd=dist.sd)) - if fn == operator.sub: - return cls.branch( - operator.add, - left_tree, - FlatTree.branch(operator.neg, right_tree, None) + is_unary = dist.right is None + if is_unary and dist.right is not None: + raise ValueError(f"Multiple arguments provided for unary operator {dist.fn}") + if dist.fn == operator.neg: + if isinstance(dist.left, NormalDistribution): + return cls.build(NormalDistribution(mean=-dist.left.mean, sd=dist.left.sd)) + if dist.fn == operator.sub: + return cls.build( + ComplexDistribution( + dist.left, + ComplexDistribution( + dist.right, + right=None, + fn=operator.neg, + fn_str="-" + ), + fn=operator.add, + fn_str="+" + ) ) + left_tree = cls.build(dist.left) + right_tree = cls.build(dist.right) + if left_tree.is_leaf: dists.append(left_tree.dist) - elif left_tree.fn == fn and fn in cls.COMMUTABLE_OPERATIONS: + elif left_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: dists.extend(left_tree.dists) else: children.append(left_tree) if right_tree is not None: if right_tree.is_leaf: dists.append(right_tree.dist) - elif right_tree.fn == fn and fn in cls.COMMUTABLE_OPERATIONS: + elif right_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: dists.extend(right_tree.dists) else: children.append(right_tree) dists.sort(key=lambda d: type(d).__name__) - return cls(fn=fn, dists=dists, children=children, is_unary=is_unary) + return cls(fn=dist.fn, dists=dists, children=children, is_unary=is_unary) - def _join_dists(self, name, join_fn): - dist_indexes = [i for i in range(len(self.dists)) if type(self.dists[i]).__name__ == name] - if len(dist_indexes) == 0: - return None - first_index = dist_indexes[0] - for i in dist_indexes[1:]: - self.dists[first_index] = join_fn(self.dists[first_index], self.dists[i]) - self.dists[i] = None - self.dists = [d for d in self.dists if d is not None] + def _join_dists(self, left_type, right_type, join_fn, commutative=True): + dists = [] + acc = None + acc_is_left = True + for x in self.dists: + if acc is None and isinstance(x, left_type): + acc = x + elif acc is None and commutative and isinstance(x, right_type): + acc = x + acc_is_left = False + elif acc is not None and isinstance(x, right_type) and acc_is_left: + acc = join_fn(acc, x) + elif acc is not None and commutative and isinstance(x, left_type) and not acc_is_left: + acc = join_fn(x, acc) + else: + dists.append(x) + + if acc is not None: + dists.insert(0, acc) + self.dists = dists def simplify(self): if self.is_leaf: @@ -1792,12 +1807,11 @@ def simplify(self): simplified_children = [child.simplify() for child in self.children] if self.fn == operator.add: - self._join_dists("NormalDistribution", lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2))) - self._join_dists("float", lambda x, y: x + y) - self._join_dists("int", lambda x, y: x + y) + self._join_dists(NormalDistribution, NormalDistribution, lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2))) + self._join_dists(NormalDistribution, float, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)) elif self.fn == operator.mul: - self._join_dists("LognormalDistribution", lambda x, y: LognormalDistribution(norm_mean=x.norm_mean + y.norm_mean, norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2))) - self._join_dists("float", lambda x, y: x * y) - self._join_dists("int", lambda x, y: x * y) + self._join_dists(LognormalDistribution, LognormalDistribution, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean + y.norm_mean, norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2))) + self._join_dists(LognormalDistribution, float, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean * y, norm_sd=x.norm_sd)) + self._join_dists(NormalDistribution, float, lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y)) return reduce(lambda acc, x: ComplexDistribution(acc, x), simplified_children + self.dists) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index e5a5f68..85c4b77 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -6,7 +6,7 @@ from ..squigglepy.distributions import * -def test_simplify_normal_add(): +def test_simplify_add_normal(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) sum2 = x + y @@ -15,7 +15,7 @@ def test_simplify_normal_add(): assert simplified2.mean == 3 assert simplified2.sd == approx(np.sqrt(5)) -def test_simplify_normal_add3(): +def test_simplify_add_3_normals(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) z = NormalDistribution(mean=-3, sd=2) @@ -30,7 +30,25 @@ def test_simplify_normal_add3(): assert simplified3_right.mean == 0 assert simplified3_right.sd == approx(np.sqrt(1 + 4 + 4)) -def test_simplify_lognorm_mul3(): +def test_simplify_normal_plus_const(): + x = NormalDistribution(mean=0, sd=1) + y = 2 + sum2 = x + y + simplified = sum2.simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == 2 + assert simplified.sd == approx(1) + +def simplify_scale_normal(): + x = NormalDistribution(mean=2, sd=4) + y = 1.5 + product2 = x * y + simplified = product2.simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == approx(3) + assert simplified.sd == approx(6) + +def test_simplify_mul_3_normals(): x = LognormalDistribution(norm_mean=1, norm_sd=1) y = LognormalDistribution(norm_mean=2, norm_sd=2) z = LognormalDistribution(norm_mean=-2, norm_sd=3) From a941f4cd3764d6e208e7346a7e7dd8f7c203110c Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 00:04:47 -0800 Subject: [PATCH 03/10] analytic: implement division and constant subtraction --- squigglepy/distributions.py | 96 +++++++++++++++++++++++++++++-------- tests/test_simplify.py | 55 ++++++++++++++++++--- 2 files changed, 125 insertions(+), 26 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 24a538d..f365584 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -13,6 +13,7 @@ from abc import ABC, abstractmethod from functools import reduce +from numbers import Real class BaseDistribution(ABC): @@ -1709,6 +1710,27 @@ def geometric(p): class FlatTree: + """Helper class for simplifying analytic expressions. A ``FlatTree`` is + sort of like a ``ComplexDistribution`` except that it flattens + commutative/associative operations onto a single object instead of having + one object per binary operation. + + This class operates in two phases. + + Phase 1: Generate a ``FlatTree`` object from a :ref:``BaseDistribution`` by + calling ``FlatTree.build(dist)``. This generates a tree where any series of + a single commutative/associative operation done repeatedly is flattened + onto a single ``FlatTree`` node. It also converts operations into a + normalized form, for example converting ``a - b`` into ``a + (-b)``. + + Phase 2: Generate a simplified ``Distribution`` by calling + :ref:``simplify``. This works by combing through each flat list of + distributions to find which ones can be analytically simplified (for + example, converting a sum of normal distributions into a single normal + distribution). + + """ + COMMUTABLE_OPERATIONS = set([operator.add, operator.mul]) def __init__(self, dist=None, fn=None, dists=None, children=None, is_unary=False): @@ -1726,24 +1748,27 @@ def __init__(self, dist=None, fn=None, dists=None, children=None, is_unary=False @classmethod def build(cls, dist): - if isinstance(dist, int) or isinstance(dist, float): - return FlatTree(dist=float(dist)) + if dist is None: + return None + if isinstance(dist, Real): + return FlatTree(dist=dist) if not isinstance(dist, BaseDistribution): - import ipdb; ipdb.set_trace() raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}") if not isinstance(dist, ComplexDistribution): return FlatTree(dist=dist) - # make a list of possibly-joinable distributions, plus a list of - # children as trees who could not be simplified at this level - dists = [] - children = [] is_unary = dist.right is None if is_unary and dist.right is not None: raise ValueError(f"Multiple arguments provided for unary operator {dist.fn}") + + # Simplify unary operations if dist.fn == operator.neg: + if isinstance(dist.left, Real): + return cls.build(-dist.left) if isinstance(dist.left, NormalDistribution): return cls.build(NormalDistribution(mean=-dist.left.mean, sd=dist.left.sd)) + + # Normalize binary operations if dist.fn == operator.sub: return cls.build( ComplexDistribution( @@ -1759,9 +1784,21 @@ def build(cls, dist): ) ) + # TODO: maybe use operator.invert (~) as the symbol for reciprocal. + # actually that's a bad idea because then simplify() might output + # content with ~ in it + left_tree = cls.build(dist.left) right_tree = cls.build(dist.right) + # Make a list of possibly-joinable distributions, plus a list of + # children as trees who could not be simplified at this level + dists = [] + children = [] + + # If the child nodes use the same commutable operation as ``dist``, add + # their flattened ``dists`` lists to ``dists``. Otherwise, put them in + # the irreducible list of ``children``. if left_tree.is_leaf: dists.append(left_tree.dist) elif left_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: @@ -1776,42 +1813,61 @@ def build(cls, dist): else: children.append(right_tree) - dists.sort(key=lambda d: type(d).__name__) + if dist.fn in cls.COMMUTABLE_OPERATIONS: + dists.sort(key=lambda d: type(d).__name__) return cls(fn=dist.fn, dists=dists, children=children, is_unary=is_unary) def _join_dists(self, left_type, right_type, join_fn, commutative=True): - dists = [] + simplified_dists = [] acc = None + acc_index = None acc_is_left = True - for x in self.dists: + for i, x in enumerate(self.dists): if acc is None and isinstance(x, left_type): acc = x - elif acc is None and commutative and isinstance(x, right_type): - acc = x - acc_is_left = False + acc_index = i elif acc is not None and isinstance(x, right_type) and acc_is_left: acc = join_fn(acc, x) - elif acc is not None and commutative and isinstance(x, left_type) and not acc_is_left: + elif commutative and acc is None and isinstance(x, right_type): + acc = x + acc_index = i + acc_is_left = False + elif commutative and acc is not None and isinstance(x, left_type) and not acc_is_left: acc = join_fn(x, acc) else: - dists.append(x) + simplified_dists.append(x) if acc is not None: - dists.insert(0, acc) - self.dists = dists + simplified_dists.insert(acc_index, acc) + self.dists = simplified_dists + + def _lognormal_times_const(self, norm_mean, norm_sd, k): + if k == 0: + return 0 + elif k > 0: + return LognormalDistribution(norm_mean=norm_mean + np.log(k), norm_sd=norm_sd) + else: + return -LognormalDistribution(norm_mean=norm_mean + np.log(-k), norm_sd=norm_sd) def simplify(self): if self.is_leaf: return self.dist simplified_children = [child.simplify() for child in self.children] + if self.fn == operator.add: self._join_dists(NormalDistribution, NormalDistribution, lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2))) - self._join_dists(NormalDistribution, float, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)) + self._join_dists(NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)) + elif self.fn == operator.mul: self._join_dists(LognormalDistribution, LognormalDistribution, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean + y.norm_mean, norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2))) - self._join_dists(LognormalDistribution, float, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean * y, norm_sd=x.norm_sd)) - self._join_dists(NormalDistribution, float, lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y)) + self._join_dists(LognormalDistribution, Real, lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y)) + self._join_dists(NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y)) + + elif self.fn == operator.truediv: + self._join_dists(LognormalDistribution, LognormalDistribution, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean - y.norm_mean, norm_sd=np.sqrt(x.norm_sd ** 2 + y.norm_sd ** 2)), commutative=False) + self._join_dists(LognormalDistribution, Real, lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), commutative=False) + self._join_dists(Real, LognormalDistribution, lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), commutative=False) return reduce(lambda acc, x: ComplexDistribution(acc, x), simplified_children + self.dists) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 85c4b77..bd911f6 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -33,11 +33,11 @@ def test_simplify_add_3_normals(): def test_simplify_normal_plus_const(): x = NormalDistribution(mean=0, sd=1) y = 2 - sum2 = x + y - simplified = sum2.simplify() + z = NormalDistribution(mean=1, sd=1) + simplified = (x + y + z).simplify() assert isinstance(simplified, NormalDistribution) - assert simplified.mean == 2 - assert simplified.sd == approx(1) + assert simplified.mean == 3 + assert simplified.sd == approx(np.sqrt(2)) def simplify_scale_normal(): x = NormalDistribution(mean=2, sd=4) @@ -48,7 +48,7 @@ def simplify_scale_normal(): assert simplified.mean == approx(3) assert simplified.sd == approx(6) -def test_simplify_mul_3_normals(): +def test_simplify_mul_3_lognorms(): x = LognormalDistribution(norm_mean=1, norm_sd=1) y = LognormalDistribution(norm_mean=2, norm_sd=2) z = LognormalDistribution(norm_mean=-2, norm_sd=3) @@ -63,7 +63,7 @@ def test_simplify_mul_3_normals(): assert simplified3_right.norm_mean == 1 assert simplified3_right.norm_sd == approx(np.sqrt(1 + 4 + 9)) -def test_simplify_sub(): +def test_simplify_sub_normals(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) difference = x - y @@ -71,3 +71,46 @@ def test_simplify_sub(): assert isinstance(simplified, NormalDistribution) assert simplified.mean == -1 assert simplified.sd == approx(np.sqrt(5)) + +def test_simplify_normal_minus_const(): + x = NormalDistribution(mean=0, sd=1) + y = 2 + simplified = (x - y).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == -2 + assert simplified.sd == approx(1) + +@given( + norm_mean=st.floats(min_value=-10, max_value=10), + norm_sd=st.floats(min_value=0.1, max_value=3), + y=st.floats(min_value=0.1, max_value=10), +) +def test_simplify_div_lognormal_by_constant(norm_mean, norm_sd, y): + x = LognormalDistribution(norm_mean=norm_mean, norm_sd=norm_sd) + simplified = (x / y).simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(norm_mean - np.log(y)) + assert simplified.norm_sd == approx(norm_sd) + +@given( + x=st.floats(min_value=0.1, max_value=10), + norm_mean=st.floats(min_value=-10, max_value=10), + norm_sd=st.floats(min_value=0.1, max_value=3) +) +@example(x=1, norm_mean=0, norm_sd=1) +def test_simplify_div_constant_by_lognormal(x, norm_mean, norm_sd): + y = LognormalDistribution(norm_mean=norm_mean, norm_sd=norm_sd) + simplified = (x / y).simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(np.log(x) - norm_mean) + assert simplified.norm_sd == approx(norm_sd) + + +def test_simplify_lognormals(): + x = LognormalDistribution(norm_mean=1, norm_sd=1) + y = LognormalDistribution(norm_mean=2, norm_sd=2) + quotient = x / y + simplified = quotient.simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == -1 + assert simplified.norm_sd == approx(np.sqrt(5)) From 2bc573de2270f9de5567b0784cead8e1ff127653 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 00:26:55 -0800 Subject: [PATCH 04/10] analytic: implement pow --- squigglepy/distributions.py | 121 ++++++++++++++++++++++++++++++------ tests/test_simplify.py | 46 +++++++++++++- 2 files changed, 147 insertions(+), 20 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index f365584..5123b5a 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1773,14 +1773,9 @@ def build(cls, dist): return cls.build( ComplexDistribution( dist.left, - ComplexDistribution( - dist.right, - right=None, - fn=operator.neg, - fn_str="-" - ), + ComplexDistribution(dist.right, right=None, fn=operator.neg, fn_str="-"), fn=operator.add, - fn_str="+" + fn_str="+", ) ) @@ -1818,7 +1813,7 @@ def build(cls, dist): return cls(fn=dist.fn, dists=dists, children=children, is_unary=is_unary) - def _join_dists(self, left_type, right_type, join_fn, commutative=True): + def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None): simplified_dists = [] acc = None acc_index = None @@ -1827,13 +1822,24 @@ def _join_dists(self, left_type, right_type, join_fn, commutative=True): if acc is None and isinstance(x, left_type): acc = x acc_index = i - elif acc is not None and isinstance(x, right_type) and acc_is_left: + elif ( + acc is not None + and isinstance(x, right_type) + and acc_is_left + and (condition is None or condition(acc, x)) + ): acc = join_fn(acc, x) elif commutative and acc is None and isinstance(x, right_type): acc = x acc_index = i acc_is_left = False - elif commutative and acc is not None and isinstance(x, left_type) and not acc_is_left: + elif ( + commutative + and acc is not None + and isinstance(x, left_type) + and not acc_is_left + and (condition is None or condition(x, acc)) + ): acc = join_fn(x, acc) else: simplified_dists.append(x) @@ -1857,17 +1863,96 @@ def simplify(self): simplified_children = [child.simplify() for child in self.children] if self.fn == operator.add: - self._join_dists(NormalDistribution, NormalDistribution, lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2))) - self._join_dists(NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)) + self._join_dists( + NormalDistribution, + NormalDistribution, + lambda x, y: NormalDistribution( + mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2) + ), + ) + self._join_dists( + NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) + ) elif self.fn == operator.mul: - self._join_dists(LognormalDistribution, LognormalDistribution, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean + y.norm_mean, norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2))) - self._join_dists(LognormalDistribution, Real, lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y)) - self._join_dists(NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y)) + self._join_dists( + LognormalDistribution, + LognormalDistribution, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean + y.norm_mean, + norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2), + ), + ) + self._join_dists( + LognormalDistribution, + Real, + lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y), + ) + self._join_dists( + NormalDistribution, + Real, + lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y), + ) elif self.fn == operator.truediv: - self._join_dists(LognormalDistribution, LognormalDistribution, lambda x, y: LognormalDistribution(norm_mean=x.norm_mean - y.norm_mean, norm_sd=np.sqrt(x.norm_sd ** 2 + y.norm_sd ** 2)), commutative=False) - self._join_dists(LognormalDistribution, Real, lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), commutative=False) - self._join_dists(Real, LognormalDistribution, lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), commutative=False) + self._join_dists( + LognormalDistribution, + LognormalDistribution, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean - y.norm_mean, + norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2), + ), + commutative=False, + ) + self._join_dists( + LognormalDistribution, + Real, + lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), + commutative=False, + ) + self._join_dists( + Real, + LognormalDistribution, + lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), + commutative=False, + ) + self._join_dists( + NormalDistribution, + Real, + lambda x, y: NormalDistribution(mean=x.mean / y, sd=x.sd / y), + commutative=False, + ) + self._join_dists( + Real, + NormalDistribution, + lambda x, y: NormalDistribution(mean=x / y.mean, sd=x / y.sd), + commutative=False, + ) + + elif self.fn == operator.pow: + self._join_dists( + LognormalDistribution, + Real, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean * y, norm_sd=x.norm_sd * y + ), + commutative=False, + condition=lambda x, y: y > 0, + ) + + # There are a few other operations that can be simplified but they're a + # bit more obscure: + # + # Bernoulli + Bernoulli = binomial + # binomial + binomial = binomial + # Poisson + Poisson = Poisson + # Cauchy + Cauchy = Cauchy + # Gamma + Gamma = Gamma + # Chi^2 + Chi^2 = Chi^2 + # central normal / central normal = Cauchy + # + # There are also a lot of known analytic expressions for combinations + # of distributions but where the analytic expression isn't a named + # distribution (and some of them are really complicated). return reduce(lambda acc, x: ComplexDistribution(acc, x), simplified_children + self.dists) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index bd911f6..bf65411 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -6,6 +6,7 @@ from ..squigglepy.distributions import * + def test_simplify_add_normal(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) @@ -15,6 +16,7 @@ def test_simplify_add_normal(): assert simplified2.mean == 3 assert simplified2.sd == approx(np.sqrt(5)) + def test_simplify_add_3_normals(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) @@ -30,6 +32,7 @@ def test_simplify_add_3_normals(): assert simplified3_right.mean == 0 assert simplified3_right.sd == approx(np.sqrt(1 + 4 + 4)) + def test_simplify_normal_plus_const(): x = NormalDistribution(mean=0, sd=1) y = 2 @@ -39,6 +42,7 @@ def test_simplify_normal_plus_const(): assert simplified.mean == 3 assert simplified.sd == approx(np.sqrt(2)) + def simplify_scale_normal(): x = NormalDistribution(mean=2, sd=4) y = 1.5 @@ -48,6 +52,7 @@ def simplify_scale_normal(): assert simplified.mean == approx(3) assert simplified.sd == approx(6) + def test_simplify_mul_3_lognorms(): x = LognormalDistribution(norm_mean=1, norm_sd=1) y = LognormalDistribution(norm_mean=2, norm_sd=2) @@ -63,6 +68,7 @@ def test_simplify_mul_3_lognorms(): assert simplified3_right.norm_mean == 1 assert simplified3_right.norm_sd == approx(np.sqrt(1 + 4 + 9)) + def test_simplify_sub_normals(): x = NormalDistribution(mean=1, sd=1) y = NormalDistribution(mean=2, sd=2) @@ -72,6 +78,7 @@ def test_simplify_sub_normals(): assert simplified.mean == -1 assert simplified.sd == approx(np.sqrt(5)) + def test_simplify_normal_minus_const(): x = NormalDistribution(mean=0, sd=1) y = 2 @@ -80,6 +87,7 @@ def test_simplify_normal_minus_const(): assert simplified.mean == -2 assert simplified.sd == approx(1) + @given( norm_mean=st.floats(min_value=-10, max_value=10), norm_sd=st.floats(min_value=0.1, max_value=3), @@ -92,10 +100,11 @@ def test_simplify_div_lognormal_by_constant(norm_mean, norm_sd, y): assert simplified.norm_mean == approx(norm_mean - np.log(y)) assert simplified.norm_sd == approx(norm_sd) + @given( x=st.floats(min_value=0.1, max_value=10), norm_mean=st.floats(min_value=-10, max_value=10), - norm_sd=st.floats(min_value=0.1, max_value=3) + norm_sd=st.floats(min_value=0.1, max_value=3), ) @example(x=1, norm_mean=0, norm_sd=1) def test_simplify_div_constant_by_lognormal(x, norm_mean, norm_sd): @@ -106,7 +115,7 @@ def test_simplify_div_constant_by_lognormal(x, norm_mean, norm_sd): assert simplified.norm_sd == approx(norm_sd) -def test_simplify_lognormals(): +def test_simplify_div_lognormals(): x = LognormalDistribution(norm_mean=1, norm_sd=1) y = LognormalDistribution(norm_mean=2, norm_sd=2) quotient = x / y @@ -114,3 +123,36 @@ def test_simplify_lognormals(): assert isinstance(simplified, LognormalDistribution) assert simplified.norm_mean == -1 assert simplified.norm_sd == approx(np.sqrt(5)) + + +def test_simplify_div_normal_by_const(): + x = NormalDistribution(mean=3, sd=1) + y = 2 + simplified = (x / y).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == approx(1.5) + assert simplified.sd == approx(0.5) + + +def test_simplify_div_const_by_normal(): + x = 2 + y = NormalDistribution(mean=3, sd=2) + simplified = (x / y).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == approx(2 / 3) + assert simplified.sd == approx(1) + + +def test_simplify_lognorm_pow(): + x = LognormalDistribution(norm_mean=3, norm_sd=2) + y = 2 + product = x ** y + simplified = product.simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(6) + assert simplified.norm_sd == approx(4) + + y = -1 + product = x ** y + simplified = product.simplify() + assert isinstance(simplified, ComplexDistribution) From 59a4bb9130ca1787c6139855a39bf08dd8ab38b7 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 00:31:21 -0800 Subject: [PATCH 05/10] analytic: don't simplify clipped dists --- squigglepy/distributions.py | 5 ++++- tests/test_simplify.py | 19 +++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 5123b5a..3a0589a 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1819,7 +1819,10 @@ def _join_dists(self, left_type, right_type, join_fn, commutative=True, conditio acc_index = None acc_is_left = True for i, x in enumerate(self.dists): - if acc is None and isinstance(x, left_type): + if isinstance(x, BaseDistribution) and (x.lclip is not None or x.rclip is not None): + # We can't simplify a clipped distribution + simplified_dists.append(x) + elif acc is None and isinstance(x, left_type): acc = x acc_index = i elif ( diff --git a/tests/test_simplify.py b/tests/test_simplify.py index bf65411..1810fb4 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -146,13 +146,28 @@ def test_simplify_div_const_by_normal(): def test_simplify_lognorm_pow(): x = LognormalDistribution(norm_mean=3, norm_sd=2) y = 2 - product = x ** y + product = x**y simplified = product.simplify() assert isinstance(simplified, LognormalDistribution) assert simplified.norm_mean == approx(6) assert simplified.norm_sd == approx(4) y = -1 - product = x ** y + product = x**y simplified = product.simplify() assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_clipped_normal(): + x = NormalDistribution(mean=1, sd=1) + y = NormalDistribution(mean=0, sd=1, lclip=1) + z = NormalDistribution(mean=3, sd=2) + simplified = (x + y + z).simplify() + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.left, NormalDistribution) + assert isinstance(simplified.right, NormalDistribution) + assert simplified.left.mean == 4 + assert simplified.left.sd == approx(np.sqrt(5)) + assert simplified.right.mean == 0 + assert simplified.right.sd == 1 + assert simplified.right.lclip == 1 From 275971fc65d3caf7caefae2bd38bcb50f2e1bda6 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 13:17:10 -0800 Subject: [PATCH 06/10] analytic: get nested trees working and support more distributions --- squigglepy/distributions.py | 210 +++++++++++++++++++++++++----------- tests/test_simplify.py | 151 +++++++++++++++++++------- 2 files changed, 263 insertions(+), 98 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 3a0589a..defc5ef 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -175,6 +175,40 @@ def __hash__(self): return hash(repr(self)) def simplify(self): + """Simplify a distribution by evaluating all operations that can be + performed analytically, for example by reducing a sum of normal + distributions into a single normal distribution. Return a new + ``OperableDistribution`` that represents the simplified distribution. + + Possible simplifications: + any - any = any + (-any) + any / constant = any * (1 / constant) + + -Normal = Normal + + Normal + Normal = Normal + constant + Normal = Normal + Bernoulli + Bernoulli = Binomial + Bernoulli + Binomial = Binomial + Binomial + Binomial = Binomial + Chi-Square + Chi-Square = Chi-Square + Exponential + Exponential = Exponential + Exponential + Gamma = Gamma + Gamma + Gamma = Gamma + Poisson + Poisson = Poisson + + constant * Normal = Normal + Lognormal * Lognormal = Lognormal + constant * Lognormal = Lognormal + constant * Exponential = Exponential + constant * Gamma = Gamma + + Lognormal / Lognormal = Lognormal + constant / Lognormal = Lognormal + + Lognormal ** constant = Lognormal + + """ return self @@ -1733,29 +1767,38 @@ class FlatTree: COMMUTABLE_OPERATIONS = set([operator.add, operator.mul]) - def __init__(self, dist=None, fn=None, dists=None, children=None, is_unary=False): + def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False): self.dist = dist self.fn = fn - self.dists = dists + self.fn_str = fn_str self.children = children self.is_unary = is_unary if dist is not None: self.is_leaf = True - elif fn is not None and dists is not None: + elif fn is not None and children is not None: self.is_leaf = False else: raise ValueError("Missing arguments to FlatTree constructor") + def __str__(self): + if self.is_leaf: + return f"FlatTree({self.dist})" + else: + return "FlatTree({})[{}]".format(self.fn_str, ", ".join(map(str, self.children))) + + def __repr__(self): + return str(self) + @classmethod def build(cls, dist): if dist is None: return None if isinstance(dist, Real): - return FlatTree(dist=dist) + return cls(dist=dist) if not isinstance(dist, BaseDistribution): raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}") if not isinstance(dist, ComplexDistribution): - return FlatTree(dist=dist) + return cls(dist=dist) is_unary = dist.right is None if is_unary and dist.right is not None: @@ -1768,7 +1811,7 @@ def build(cls, dist): if isinstance(dist.left, NormalDistribution): return cls.build(NormalDistribution(mean=-dist.left.mean, sd=dist.left.sd)) - # Normalize binary operations + # Convert x - y into x + (-y) if dist.fn == operator.sub: return cls.build( ComplexDistribution( @@ -1779,46 +1822,52 @@ def build(cls, dist): ) ) - # TODO: maybe use operator.invert (~) as the symbol for reciprocal. - # actually that's a bad idea because then simplify() might output - # content with ~ in it + # If the denominator is a constant, replace division by constant + # with multiplication by the reciprocal of the constant + if dist.fn == operator.truediv and isinstance(dist.right, Real): + if dist.right == 0: + raise ZeroDivisionError("Division by zero in ComplexDistribution: {dist}") + return cls.build( + ComplexDistribution( + dist.left, + 1 / dist.right, + fn=operator.mul, + fn_str="*", + ) + ) left_tree = cls.build(dist.left) right_tree = cls.build(dist.right) # Make a list of possibly-joinable distributions, plus a list of # children as trees who could not be simplified at this level - dists = [] children = [] # If the child nodes use the same commutable operation as ``dist``, add - # their flattened ``dists`` lists to ``dists``. Otherwise, put them in - # the irreducible list of ``children``. + # their flattened ``children`` lists to ``children``. Otherwise, put + # the whole node in ``children``. if left_tree.is_leaf: - dists.append(left_tree.dist) + children.append(left_tree.dist) elif left_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: - dists.extend(left_tree.dists) + children.extend(left_tree.children) else: children.append(left_tree) if right_tree is not None: if right_tree.is_leaf: - dists.append(right_tree.dist) + children.append(right_tree.dist) elif right_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: - dists.extend(right_tree.dists) + children.extend(right_tree.children) else: children.append(right_tree) - if dist.fn in cls.COMMUTABLE_OPERATIONS: - dists.sort(key=lambda d: type(d).__name__) - - return cls(fn=dist.fn, dists=dists, children=children, is_unary=is_unary) + return cls(fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary) def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None): simplified_dists = [] acc = None acc_index = None acc_is_left = True - for i, x in enumerate(self.dists): + for i, x in enumerate(self.children): if isinstance(x, BaseDistribution) and (x.lclip is not None or x.rclip is not None): # We can't simplify a clipped distribution simplified_dists.append(x) @@ -1849,7 +1898,7 @@ def _join_dists(self, left_type, right_type, join_fn, commutative=True, conditio if acc is not None: simplified_dists.insert(acc_index, acc) - self.dists = simplified_dists + self.children = simplified_dists def _lognormal_times_const(self, norm_mean, norm_sd, k): if k == 0: @@ -1860,10 +1909,14 @@ def _lognormal_times_const(self, norm_mean, norm_sd, k): return -LognormalDistribution(norm_mean=norm_mean + np.log(-k), norm_sd=norm_sd) def simplify(self): + """Convert a FlatTree back into a Distribution, simplifying as much as + possible.""" if self.is_leaf: return self.dist - simplified_children = [child.simplify() for child in self.children] + for i in range(len(self.children)): + if isinstance(self.children[i], FlatTree): + self.children[i] = self.children[i].simplify() if self.fn == operator.add: self._join_dists( @@ -1874,10 +1927,63 @@ def simplify(self): ), ) self._join_dists( - NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) + NormalDistribution, + Real, + lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) + ) + self._join_dists( + BernoulliDistribution, + BernoulliDistribution, + lambda x, y: BinomialDistribution(n=2, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + BinomialDistribution, + BernoulliDistribution, + lambda x, y: BinomialDistribution(n=x.n + 1, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + BinomialDistribution, + BinomialDistribution, + lambda x, y: BinomialDistribution(n=x.n + y.n, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + ChiSquareDistribution, + ChiSquareDistribution, + lambda x, y: ChiSquareDistribution(df=x.df + y.df), + ) + self._join_dists( + ExponentialDistribution, + ExponentialDistribution, + lambda x, y: GammaDistribution(shape=2, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + ExponentialDistribution, + GammaDistribution, + lambda x, y: GammaDistribution(shape=y.shape + 1, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + GammaDistribution, + GammaDistribution, + lambda x, y: GammaDistribution(shape=x.shape + y.shape, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + PoissonDistribution, + PoissonDistribution, + lambda x, y: PoissonDistribution(lam=x.lam + y.lam), ) elif self.fn == operator.mul: + self._join_dists( + NormalDistribution, + Real, + lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y), + ) self._join_dists( LognormalDistribution, LognormalDistribution, @@ -1892,9 +1998,14 @@ def simplify(self): lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y), ) self._join_dists( - NormalDistribution, + ExponentialDistribution, Real, - lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y), + lambda x, y: ExponentialDistribution(scale=x.scale * y), + ) + self._join_dists( + GammaDistribution, + Real, + lambda x, y: GammaDistribution(shape=x.shape, scale=x.scale * y), ) elif self.fn == operator.truediv: @@ -1907,30 +2018,24 @@ def simplify(self): ), commutative=False, ) - self._join_dists( - LognormalDistribution, - Real, - lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), - commutative=False, - ) + # self._join_dists( + # LognormalDistribution, + # Real, + # lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), + # commutative=False, + # ) self._join_dists( Real, LognormalDistribution, lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), commutative=False, ) - self._join_dists( - NormalDistribution, - Real, - lambda x, y: NormalDistribution(mean=x.mean / y, sd=x.sd / y), - commutative=False, - ) - self._join_dists( - Real, - NormalDistribution, - lambda x, y: NormalDistribution(mean=x / y.mean, sd=x / y.sd), - commutative=False, - ) + # self._join_dists( + # NormalDistribution, + # Real, + # lambda x, y: NormalDistribution(mean=x.mean / y, sd=x.sd / y), + # commutative=False, + # ) elif self.fn == operator.pow: self._join_dists( @@ -1943,19 +2048,4 @@ def simplify(self): condition=lambda x, y: y > 0, ) - # There are a few other operations that can be simplified but they're a - # bit more obscure: - # - # Bernoulli + Bernoulli = binomial - # binomial + binomial = binomial - # Poisson + Poisson = Poisson - # Cauchy + Cauchy = Cauchy - # Gamma + Gamma = Gamma - # Chi^2 + Chi^2 = Chi^2 - # central normal / central normal = Cauchy - # - # There are also a lot of known analytic expressions for combinations - # of distributions but where the analytic expression isn't a named - # distribution (and some of them are really complicated). - - return reduce(lambda acc, x: ComplexDistribution(acc, x), simplified_children + self.dists) + return reduce(lambda acc, x: ComplexDistribution(acc, x), self.children) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 1810fb4..11b37a9 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -7,9 +7,9 @@ from ..squigglepy.distributions import * -def test_simplify_add_normal(): - x = NormalDistribution(mean=1, sd=1) - y = NormalDistribution(mean=2, sd=2) +def test_simplify_add_norm(): + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) sum2 = x + y simplified2 = sum2.simplify() assert isinstance(simplified2, NormalDistribution) @@ -18,9 +18,9 @@ def test_simplify_add_normal(): def test_simplify_add_3_normals(): - x = NormalDistribution(mean=1, sd=1) - y = NormalDistribution(mean=2, sd=2) - z = NormalDistribution(mean=-3, sd=2) + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) + z = norm(mean=-3, sd=2) sum3_left = (x + y) + z sum3_right = x + (y + z) simplified3_left = sum3_left.simplify() @@ -34,17 +34,17 @@ def test_simplify_add_3_normals(): def test_simplify_normal_plus_const(): - x = NormalDistribution(mean=0, sd=1) + x = norm(mean=0, sd=1) y = 2 - z = NormalDistribution(mean=1, sd=1) + z = norm(mean=1, sd=1) simplified = (x + y + z).simplify() assert isinstance(simplified, NormalDistribution) assert simplified.mean == 3 assert simplified.sd == approx(np.sqrt(2)) -def simplify_scale_normal(): - x = NormalDistribution(mean=2, sd=4) +def simplify_scale_norm(): + x = norm(mean=2, sd=4) y = 1.5 product2 = x * y simplified = product2.simplify() @@ -54,9 +54,9 @@ def simplify_scale_normal(): def test_simplify_mul_3_lognorms(): - x = LognormalDistribution(norm_mean=1, norm_sd=1) - y = LognormalDistribution(norm_mean=2, norm_sd=2) - z = LognormalDistribution(norm_mean=-2, norm_sd=3) + x = lognorm(norm_mean=1, norm_sd=1) + y = lognorm(norm_mean=2, norm_sd=2) + z = lognorm(norm_mean=-2, norm_sd=3) product3_left = (x * y) * z product3_right = x * (y * z) simplified3_left = product3_left.simplify() @@ -70,8 +70,8 @@ def test_simplify_mul_3_lognorms(): def test_simplify_sub_normals(): - x = NormalDistribution(mean=1, sd=1) - y = NormalDistribution(mean=2, sd=2) + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) difference = x - y simplified = difference.simplify() assert isinstance(simplified, NormalDistribution) @@ -80,7 +80,7 @@ def test_simplify_sub_normals(): def test_simplify_normal_minus_const(): - x = NormalDistribution(mean=0, sd=1) + x = norm(mean=0, sd=1) y = 2 simplified = (x - y).simplify() assert isinstance(simplified, NormalDistribution) @@ -93,8 +93,8 @@ def test_simplify_normal_minus_const(): norm_sd=st.floats(min_value=0.1, max_value=3), y=st.floats(min_value=0.1, max_value=10), ) -def test_simplify_div_lognormal_by_constant(norm_mean, norm_sd, y): - x = LognormalDistribution(norm_mean=norm_mean, norm_sd=norm_sd) +def test_simplify_div_lognorm_by_constant(norm_mean, norm_sd, y): + x = lognorm(norm_mean=norm_mean, norm_sd=norm_sd) simplified = (x / y).simplify() assert isinstance(simplified, LognormalDistribution) assert simplified.norm_mean == approx(norm_mean - np.log(y)) @@ -107,17 +107,17 @@ def test_simplify_div_lognormal_by_constant(norm_mean, norm_sd, y): norm_sd=st.floats(min_value=0.1, max_value=3), ) @example(x=1, norm_mean=0, norm_sd=1) -def test_simplify_div_constant_by_lognormal(x, norm_mean, norm_sd): - y = LognormalDistribution(norm_mean=norm_mean, norm_sd=norm_sd) +def test_simplify_div_constant_by_lognorm(x, norm_mean, norm_sd): + y = lognorm(norm_mean=norm_mean, norm_sd=norm_sd) simplified = (x / y).simplify() assert isinstance(simplified, LognormalDistribution) assert simplified.norm_mean == approx(np.log(x) - norm_mean) assert simplified.norm_sd == approx(norm_sd) -def test_simplify_div_lognormals(): - x = LognormalDistribution(norm_mean=1, norm_sd=1) - y = LognormalDistribution(norm_mean=2, norm_sd=2) +def test_simplify_div_lognorms(): + x = lognorm(norm_mean=1, norm_sd=1) + y = lognorm(norm_mean=2, norm_sd=2) quotient = x / y simplified = quotient.simplify() assert isinstance(simplified, LognormalDistribution) @@ -126,7 +126,7 @@ def test_simplify_div_lognormals(): def test_simplify_div_normal_by_const(): - x = NormalDistribution(mean=3, sd=1) + x = norm(mean=3, sd=1) y = 2 simplified = (x / y).simplify() assert isinstance(simplified, NormalDistribution) @@ -134,17 +134,8 @@ def test_simplify_div_normal_by_const(): assert simplified.sd == approx(0.5) -def test_simplify_div_const_by_normal(): - x = 2 - y = NormalDistribution(mean=3, sd=2) - simplified = (x / y).simplify() - assert isinstance(simplified, NormalDistribution) - assert simplified.mean == approx(2 / 3) - assert simplified.sd == approx(1) - - def test_simplify_lognorm_pow(): - x = LognormalDistribution(norm_mean=3, norm_sd=2) + x = lognorm(norm_mean=3, norm_sd=2) y = 2 product = x**y simplified = product.simplify() @@ -158,10 +149,10 @@ def test_simplify_lognorm_pow(): assert isinstance(simplified, ComplexDistribution) -def test_simplify_clipped_normal(): - x = NormalDistribution(mean=1, sd=1) - y = NormalDistribution(mean=0, sd=1, lclip=1) - z = NormalDistribution(mean=3, sd=2) +def test_simplify_clipped_norm(): + x = norm(mean=1, sd=1) + y = norm(mean=0, sd=1, lclip=1) + z = norm(mean=3, sd=2) simplified = (x + y + z).simplify() assert isinstance(simplified, ComplexDistribution) assert isinstance(simplified.left, NormalDistribution) @@ -171,3 +162,87 @@ def test_simplify_clipped_normal(): assert simplified.right.mean == 0 assert simplified.right.sd == 1 assert simplified.right.lclip == 1 + + +def test_simplify_bernoulli_sum(): + simplified = (bernoulli(p=0.5) + bernoulli(p=0.5)).simplify() + assert isinstance(simplified, BinomialDistribution) + assert simplified.n == 2 + assert simplified.p == 0.5 + + # Cannot simplify if the probabilities are different + simplified = (bernoulli(p=0.5) + bernoulli(p=0.6)).simplify() + assert isinstance(simplified, ComplexDistribution) + +def test_simplify_bernoulli_plus_binomial(): + simplified = (bernoulli(p=0.5) + binomial(n=10, p=0.2) + binomial(n=2, p=0.5) + binomial(n=3, p=0.5) + bernoulli(p=0.5)).simplify() + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.left, BinomialDistribution) + assert isinstance(simplified.right, BinomialDistribution) + assert simplified.left.n == 7 + assert simplified.left.p == 0.5 + assert simplified.right.n == 10 + assert simplified.right.p == 0.2 + + +def test_simplify_gamma_sum(): + simplified = (gamma(shape=2, scale=1) + gamma(shape=4, scale=1)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 6 + assert simplified.scale == 1 + + # Cannot simplify if the scales are different + simplified = (gamma(shape=2, scale=1) + gamma(shape=2, scale=2)).simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_scale_gamma(): + simplified = (2 * gamma(shape=2, scale=3)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 2 + assert simplified.scale == 6 + + +def test_simplify_exponential_sum(): + simplified = (exponential(scale=2) + exponential(scale=2)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 2 + assert simplified.scale == 2 + + # Cannot simplify if the rates are different + simplified = (exponential(scale=2) + exponential(scale=3)).simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_exponential_gamma_sum(): + simplified = (5 * (exponential(scale=3) + gamma(shape=2, scale=3) + exponential(scale=3))).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 4 + assert simplified.scale == 15 + + +def test_simplify_big_sum(): + simplified = ( + 2 * norm(mean=1, sd=1) + + lognorm(norm_mean=1, norm_sd=1) * lognorm(norm_mean=3, norm_sd=2) + + gamma(shape=2, scale=3) + + exponential(scale=3) / 5 + + exponential(scale=3) + - norm(mean=2, sd=1) + ).simplify() + + # simplifies to norm(0, sqrt(5)) + lognorm(4, sqrt(5)) + gamma(3, 3) + exponential(6) + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.right, ExponentialDistribution) + assert simplified.right.scale == approx(3 / 5) + assert isinstance(simplified.left, ComplexDistribution) + assert isinstance(simplified.left.right, GammaDistribution) + assert simplified.left.right.shape == 3 + assert simplified.left.right.scale == 3 + assert isinstance(simplified.left.left, ComplexDistribution) + assert isinstance(simplified.left.left.right, LognormalDistribution) + assert simplified.left.left.right.norm_mean == 4 + assert simplified.left.left.right.norm_sd == approx(np.sqrt(5)) + assert isinstance(simplified.left.left.left, NormalDistribution) + assert simplified.left.left.left.mean == 0 + assert simplified.left.left.left.sd == approx(np.sqrt(5)) From 2837a7765ce88122fe1d80659d282c8d08801534 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 14:07:41 -0800 Subject: [PATCH 07/10] analytic: bugfixes for unary operators --- squigglepy/distributions.py | 28 +++++++++++++++++----------- tests/test_simplify.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index defc5ef..51b4983 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1767,12 +1767,13 @@ class FlatTree: COMMUTABLE_OPERATIONS = set([operator.add, operator.mul]) - def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False): + def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False, infix=None): self.dist = dist self.fn = fn self.fn_str = fn_str self.children = children self.is_unary = is_unary + self.infix = infix if dist is not None: self.is_leaf = True elif fn is not None and children is not None: @@ -1804,13 +1805,6 @@ def build(cls, dist): if is_unary and dist.right is not None: raise ValueError(f"Multiple arguments provided for unary operator {dist.fn}") - # Simplify unary operations - if dist.fn == operator.neg: - if isinstance(dist.left, Real): - return cls.build(-dist.left) - if isinstance(dist.left, NormalDistribution): - return cls.build(NormalDistribution(mean=-dist.left.mean, sd=dist.left.sd)) - # Convert x - y into x + (-y) if dist.fn == operator.sub: return cls.build( @@ -1860,7 +1854,7 @@ def build(cls, dist): else: children.append(right_tree) - return cls(fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary) + return cls(fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix) def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None): simplified_dists = [] @@ -1900,7 +1894,8 @@ def _join_dists(self, left_type, right_type, join_fn, commutative=True, conditio simplified_dists.insert(acc_index, acc) self.children = simplified_dists - def _lognormal_times_const(self, norm_mean, norm_sd, k): + @classmethod + def _lognormal_times_const(cls, norm_mean, norm_sd, k): if k == 0: return 0 elif k > 0: @@ -1918,6 +1913,17 @@ def simplify(self): if isinstance(self.children[i], FlatTree): self.children[i] = self.children[i].simplify() + # Simplify unary operations + if len(self.children) == 1: + child = self.children[0] + if self.fn == operator.neg: + if isinstance(child, Real): + return -child + if isinstance(child, NormalDistribution): + return NormalDistribution(mean=-child.mean, sd=child.sd) + + return ComplexDistribution(child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix) + if self.fn == operator.add: self._join_dists( NormalDistribution, @@ -2048,4 +2054,4 @@ def simplify(self): condition=lambda x, y: y > 0, ) - return reduce(lambda acc, x: ComplexDistribution(acc, x), self.children) + return reduce(lambda acc, x: ComplexDistribution(acc, x, fn=self.fn, fn_str=self.fn_str), self.children) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 11b37a9..6a3623a 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -1,5 +1,6 @@ from hypothesis import assume, example, given import hypothesis.strategies as st +from numbers import Real import numpy as np import pytest from pytest import approx @@ -246,3 +247,16 @@ def test_simplify_big_sum(): assert isinstance(simplified.left.left.left, NormalDistribution) assert simplified.left.left.left.mean == 0 assert simplified.left.left.left.sd == approx(np.sqrt(5)) + + +def test_preserve_non_commutative_op(): + simplified = (2 ** 3 ** norm(mean=0, sd=1)).simplify() + assert isinstance(simplified, ComplexDistribution) + assert simplified.fn_str == "**" + assert isinstance(simplified.left, Real) + assert simplified.left == 2 + assert isinstance(simplified.right, ComplexDistribution) + assert simplified.right.fn_str == "**" + assert isinstance(simplified.right.left, Real) + assert simplified.right.left == 3 + assert isinstance(simplified.right.right, NormalDistribution) From f72ca3cdeb91dbf240c4a353c6a2e94a1fd07789 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 14:07:50 -0800 Subject: [PATCH 08/10] analytic: simplify before sampling --- squigglepy/samplers.py | 5 +++++ tests/test_samplers.py | 10 ++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/squigglepy/samplers.py b/squigglepy/samplers.py index e88df0b..031c6cc 100644 --- a/squigglepy/samplers.py +++ b/squigglepy/samplers.py @@ -39,6 +39,7 @@ LognormalDistribution, MixtureDistribution, NormalDistribution, + OperableDistribution, ParetoDistribution, PoissonDistribution, TDistribution, @@ -877,6 +878,10 @@ def sample( if verbose is None: verbose = n >= 1000000 + # Simplify distribution analytically before sampling + if isinstance(dist, OperableDistribution): + dist = dist.simplify() + # Handle loading from cache samples = None has_in_mem_cache = str(dist) in _squigglepy_internal_sample_caches diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 1861735..a768601 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -848,11 +848,12 @@ def sample_fn(): assert sample(sample_fn) == 5 +@patch.object(samplers, "gamma_sample", Mock(return_value=1)) @patch.object(samplers, "normal_sample", Mock(return_value=1)) @patch.object(samplers, "lognormal_sample", Mock(return_value=4)) def test_sample_callable_resolves_fully2(): def really_inner_sample_fn(): - return 1 + return gamma(1, 4) def inner_sample_fn(): return norm(1, 4) + lognorm(1, 10) + really_inner_sample_fn() @@ -870,15 +871,16 @@ def test_sample_invalid_input(): @patch.object(samplers, "normal_sample", Mock(return_value=100)) +@patch.object(samplers, "lognormal_sample", Mock(return_value=100)) def test_sample_math(): - assert ~(norm(0, 1) + norm(1, 2)) == 200 + assert ~(norm(0, 1) + lognorm(1, 2)) == 200 @patch.object(samplers, "normal_sample", Mock(return_value=10)) @patch.object(samplers, "lognormal_sample", Mock(return_value=100)) def test_sample_complex_math(): - obj = (2 ** norm(0, 1)) - (8 * 6) + 2 + (lognorm(10, 100) / 11) + 8 - expected = (2**10) - (8 * 6) + 2 + (100 / 11) + 8 + obj = (2 ** norm(0, 1)) - (8 * 6) + 2 + (lognorm(10, 100) + 11) / 8 + expected = (2**10) - (8 * 6) + 2 + (100 + 11) / 8 assert ~obj == expected From 4b34c513bcbb8296af59741ab07a877a50b0fe45 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Sun, 3 Dec 2023 23:38:44 -0800 Subject: [PATCH 09/10] analytic: lognorm reciprocal regardless of type of numerator --- squigglepy/distributions.py | 21 +++++++++------------ tests/test_simplify.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 51b4983..1c4afcd 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1829,6 +1829,15 @@ def build(cls, dist): fn_str="*", ) ) + if dist.fn == operator.truediv and isinstance(dist.right, LognormalDistribution): + return cls.build( + ComplexDistribution( + dist.left, + LognormalDistribution(norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd), + fn=operator.mul, + fn_str="*", + ) + ) left_tree = cls.build(dist.left) right_tree = cls.build(dist.right) @@ -2024,24 +2033,12 @@ def simplify(self): ), commutative=False, ) - # self._join_dists( - # LognormalDistribution, - # Real, - # lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, 1 / y), - # commutative=False, - # ) self._join_dists( Real, LognormalDistribution, lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), commutative=False, ) - # self._join_dists( - # NormalDistribution, - # Real, - # lambda x, y: NormalDistribution(mean=x.mean / y, sd=x.sd / y), - # commutative=False, - # ) elif self.fn == operator.pow: self._join_dists( diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 6a3623a..9253790 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -126,7 +126,7 @@ def test_simplify_div_lognorms(): assert simplified.norm_sd == approx(np.sqrt(5)) -def test_simplify_div_normal_by_const(): +def test_simplify_div_norm_by_const(): x = norm(mean=3, sd=1) y = 2 simplified = (x / y).simplify() @@ -135,6 +135,20 @@ def test_simplify_div_normal_by_const(): assert simplified.sd == approx(0.5) +def test_simplify_div_norm_by_lognorm(): + x = norm(mean=3, sd=1) + y = lognorm(norm_mean=2, norm_sd=2) + simplified = (x / y).simplify() + assert isinstance(simplified, ComplexDistribution) + assert simplified.fn_str == "*" + assert isinstance(simplified.left, NormalDistribution) + assert simplified.left.mean == 3 + assert simplified.left.sd == 1 + assert isinstance(simplified.right, LognormalDistribution) + assert simplified.right.norm_mean == -2 + assert simplified.right.norm_sd == approx(2) + + def test_simplify_lognorm_pow(): x = lognorm(norm_mean=3, norm_sd=2) y = 2 From 0781b7aa77d3014b3dd0b158bb2a5dea2649d991 Mon Sep 17 00:00:00 2001 From: Michael Dickens Date: Wed, 3 Jan 2024 07:00:25 -0800 Subject: [PATCH 10/10] analytic: fix linter --- squigglepy/distributions.py | 21 ++++++++++++++------- tests/test_simplify.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 1c4afcd..db4e1e5 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -1833,7 +1833,9 @@ def build(cls, dist): return cls.build( ComplexDistribution( dist.left, - LognormalDistribution(norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd), + LognormalDistribution( + norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd + ), fn=operator.mul, fn_str="*", ) @@ -1863,7 +1865,9 @@ def build(cls, dist): else: children.append(right_tree) - return cls(fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix) + return cls( + fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix + ) def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None): simplified_dists = [] @@ -1931,7 +1935,9 @@ def simplify(self): if isinstance(child, NormalDistribution): return NormalDistribution(mean=-child.mean, sd=child.sd) - return ComplexDistribution(child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix) + return ComplexDistribution( + child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix + ) if self.fn == operator.add: self._join_dists( @@ -1942,9 +1948,7 @@ def simplify(self): ), ) self._join_dists( - NormalDistribution, - Real, - lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) + NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) ) self._join_dists( BernoulliDistribution, @@ -2051,4 +2055,7 @@ def simplify(self): condition=lambda x, y: y > 0, ) - return reduce(lambda acc, x: ComplexDistribution(acc, x, fn=self.fn, fn_str=self.fn_str), self.children) + return reduce( + lambda acc, x: ComplexDistribution(acc, x, fn=self.fn, fn_str=self.fn_str), + self.children, + ) diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 9253790..2aa62ad 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -1,11 +1,23 @@ -from hypothesis import assume, example, given +from hypothesis import example, given import hypothesis.strategies as st from numbers import Real import numpy as np -import pytest from pytest import approx -from ..squigglepy.distributions import * +from ..squigglepy.distributions import ( + bernoulli, + binomial, + exponential, + gamma, + lognorm, + norm, + BinomialDistribution, + ComplexDistribution, + ExponentialDistribution, + GammaDistribution, + LognormalDistribution, + NormalDistribution, +) def test_simplify_add_norm(): @@ -189,8 +201,15 @@ def test_simplify_bernoulli_sum(): simplified = (bernoulli(p=0.5) + bernoulli(p=0.6)).simplify() assert isinstance(simplified, ComplexDistribution) + def test_simplify_bernoulli_plus_binomial(): - simplified = (bernoulli(p=0.5) + binomial(n=10, p=0.2) + binomial(n=2, p=0.5) + binomial(n=3, p=0.5) + bernoulli(p=0.5)).simplify() + simplified = ( + bernoulli(p=0.5) + + binomial(n=10, p=0.2) + + binomial(n=2, p=0.5) + + binomial(n=3, p=0.5) + + bernoulli(p=0.5) + ).simplify() assert isinstance(simplified, ComplexDistribution) assert isinstance(simplified.left, BinomialDistribution) assert isinstance(simplified.right, BinomialDistribution) @@ -230,7 +249,9 @@ def test_simplify_exponential_sum(): def test_simplify_exponential_gamma_sum(): - simplified = (5 * (exponential(scale=3) + gamma(shape=2, scale=3) + exponential(scale=3))).simplify() + simplified = ( + 5 * (exponential(scale=3) + gamma(shape=2, scale=3) + exponential(scale=3)) + ).simplify() assert isinstance(simplified, GammaDistribution) assert simplified.shape == 4 assert simplified.scale == 15 @@ -264,7 +285,7 @@ def test_simplify_big_sum(): def test_preserve_non_commutative_op(): - simplified = (2 ** 3 ** norm(mean=0, sd=1)).simplify() + simplified = (2**3 ** norm(mean=0, sd=1)).simplify() assert isinstance(simplified, ComplexDistribution) assert simplified.fn_str == "**" assert isinstance(simplified.left, Real)