diff --git a/squigglepy/distributions.py b/squigglepy/distributions.py index 413e749..db4e1e5 100644 --- a/squigglepy/distributions.py +++ b/squigglepy/distributions.py @@ -12,6 +12,8 @@ from collections.abc import Iterable from abc import ABC, abstractmethod +from functools import reduce +from numbers import Real class BaseDistribution(ABC): @@ -172,6 +174,43 @@ def __rpow__(self, dist): def __hash__(self): return hash(repr(self)) + def simplify(self): + """Simplify a distribution by evaluating all operations that can be + performed analytically, for example by reducing a sum of normal + distributions into a single normal distribution. Return a new + ``OperableDistribution`` that represents the simplified distribution. + + Possible simplifications: + any - any = any + (-any) + any / constant = any * (1 / constant) + + -Normal = Normal + + Normal + Normal = Normal + constant + Normal = Normal + Bernoulli + Bernoulli = Binomial + Bernoulli + Binomial = Binomial + Binomial + Binomial = Binomial + Chi-Square + Chi-Square = Chi-Square + Exponential + Exponential = Exponential + Exponential + Gamma = Gamma + Gamma + Gamma = Gamma + Poisson + Poisson = Poisson + + constant * Normal = Normal + Lognormal * Lognormal = Lognormal + constant * Lognormal = Lognormal + constant * Exponential = Exponential + constant * Gamma = Gamma + + Lognormal / Lognormal = Lognormal + constant / Lognormal = Lognormal + + Lognormal ** constant = Lognormal + + """ + return self + # Distribution are either discrete, continuous, or composite @@ -242,6 +281,9 @@ def __str__(self): raise ValueError return out + def simplify(self): + return FlatTree.build(self).simplify() + def _get_fname(f, name): if name is None: @@ -1699,3 +1741,321 @@ def geometric(p): geometric(0.1) """ return GeometricDistribution(p=p) + + +class FlatTree: + """Helper class for simplifying analytic expressions. A ``FlatTree`` is + sort of like a ``ComplexDistribution`` except that it flattens + commutative/associative operations onto a single object instead of having + one object per binary operation. + + This class operates in two phases. + + Phase 1: Generate a ``FlatTree`` object from a :ref:``BaseDistribution`` by + calling ``FlatTree.build(dist)``. This generates a tree where any series of + a single commutative/associative operation done repeatedly is flattened + onto a single ``FlatTree`` node. It also converts operations into a + normalized form, for example converting ``a - b`` into ``a + (-b)``. + + Phase 2: Generate a simplified ``Distribution`` by calling + :ref:``simplify``. This works by combing through each flat list of + distributions to find which ones can be analytically simplified (for + example, converting a sum of normal distributions into a single normal + distribution). + + """ + + COMMUTABLE_OPERATIONS = set([operator.add, operator.mul]) + + def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False, infix=None): + self.dist = dist + self.fn = fn + self.fn_str = fn_str + self.children = children + self.is_unary = is_unary + self.infix = infix + if dist is not None: + self.is_leaf = True + elif fn is not None and children is not None: + self.is_leaf = False + else: + raise ValueError("Missing arguments to FlatTree constructor") + + def __str__(self): + if self.is_leaf: + return f"FlatTree({self.dist})" + else: + return "FlatTree({})[{}]".format(self.fn_str, ", ".join(map(str, self.children))) + + def __repr__(self): + return str(self) + + @classmethod + def build(cls, dist): + if dist is None: + return None + if isinstance(dist, Real): + return cls(dist=dist) + if not isinstance(dist, BaseDistribution): + raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}") + if not isinstance(dist, ComplexDistribution): + return cls(dist=dist) + + is_unary = dist.right is None + if is_unary and dist.right is not None: + raise ValueError(f"Multiple arguments provided for unary operator {dist.fn}") + + # Convert x - y into x + (-y) + if dist.fn == operator.sub: + return cls.build( + ComplexDistribution( + dist.left, + ComplexDistribution(dist.right, right=None, fn=operator.neg, fn_str="-"), + fn=operator.add, + fn_str="+", + ) + ) + + # If the denominator is a constant, replace division by constant + # with multiplication by the reciprocal of the constant + if dist.fn == operator.truediv and isinstance(dist.right, Real): + if dist.right == 0: + raise ZeroDivisionError("Division by zero in ComplexDistribution: {dist}") + return cls.build( + ComplexDistribution( + dist.left, + 1 / dist.right, + fn=operator.mul, + fn_str="*", + ) + ) + if dist.fn == operator.truediv and isinstance(dist.right, LognormalDistribution): + return cls.build( + ComplexDistribution( + dist.left, + LognormalDistribution( + norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd + ), + fn=operator.mul, + fn_str="*", + ) + ) + + left_tree = cls.build(dist.left) + right_tree = cls.build(dist.right) + + # Make a list of possibly-joinable distributions, plus a list of + # children as trees who could not be simplified at this level + children = [] + + # If the child nodes use the same commutable operation as ``dist``, add + # their flattened ``children`` lists to ``children``. Otherwise, put + # the whole node in ``children``. + if left_tree.is_leaf: + children.append(left_tree.dist) + elif left_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: + children.extend(left_tree.children) + else: + children.append(left_tree) + if right_tree is not None: + if right_tree.is_leaf: + children.append(right_tree.dist) + elif right_tree.fn == dist.fn and dist.fn in cls.COMMUTABLE_OPERATIONS: + children.extend(right_tree.children) + else: + children.append(right_tree) + + return cls( + fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix + ) + + def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None): + simplified_dists = [] + acc = None + acc_index = None + acc_is_left = True + for i, x in enumerate(self.children): + if isinstance(x, BaseDistribution) and (x.lclip is not None or x.rclip is not None): + # We can't simplify a clipped distribution + simplified_dists.append(x) + elif acc is None and isinstance(x, left_type): + acc = x + acc_index = i + elif ( + acc is not None + and isinstance(x, right_type) + and acc_is_left + and (condition is None or condition(acc, x)) + ): + acc = join_fn(acc, x) + elif commutative and acc is None and isinstance(x, right_type): + acc = x + acc_index = i + acc_is_left = False + elif ( + commutative + and acc is not None + and isinstance(x, left_type) + and not acc_is_left + and (condition is None or condition(x, acc)) + ): + acc = join_fn(x, acc) + else: + simplified_dists.append(x) + + if acc is not None: + simplified_dists.insert(acc_index, acc) + self.children = simplified_dists + + @classmethod + def _lognormal_times_const(cls, norm_mean, norm_sd, k): + if k == 0: + return 0 + elif k > 0: + return LognormalDistribution(norm_mean=norm_mean + np.log(k), norm_sd=norm_sd) + else: + return -LognormalDistribution(norm_mean=norm_mean + np.log(-k), norm_sd=norm_sd) + + def simplify(self): + """Convert a FlatTree back into a Distribution, simplifying as much as + possible.""" + if self.is_leaf: + return self.dist + + for i in range(len(self.children)): + if isinstance(self.children[i], FlatTree): + self.children[i] = self.children[i].simplify() + + # Simplify unary operations + if len(self.children) == 1: + child = self.children[0] + if self.fn == operator.neg: + if isinstance(child, Real): + return -child + if isinstance(child, NormalDistribution): + return NormalDistribution(mean=-child.mean, sd=child.sd) + + return ComplexDistribution( + child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix + ) + + if self.fn == operator.add: + self._join_dists( + NormalDistribution, + NormalDistribution, + lambda x, y: NormalDistribution( + mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2) + ), + ) + self._join_dists( + NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd) + ) + self._join_dists( + BernoulliDistribution, + BernoulliDistribution, + lambda x, y: BinomialDistribution(n=2, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + BinomialDistribution, + BernoulliDistribution, + lambda x, y: BinomialDistribution(n=x.n + 1, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + BinomialDistribution, + BinomialDistribution, + lambda x, y: BinomialDistribution(n=x.n + y.n, p=x.p), + condition=lambda x, y: x.p == y.p, + ) + self._join_dists( + ChiSquareDistribution, + ChiSquareDistribution, + lambda x, y: ChiSquareDistribution(df=x.df + y.df), + ) + self._join_dists( + ExponentialDistribution, + ExponentialDistribution, + lambda x, y: GammaDistribution(shape=2, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + ExponentialDistribution, + GammaDistribution, + lambda x, y: GammaDistribution(shape=y.shape + 1, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + GammaDistribution, + GammaDistribution, + lambda x, y: GammaDistribution(shape=x.shape + y.shape, scale=x.scale), + condition=lambda x, y: x.scale == y.scale, + ) + self._join_dists( + PoissonDistribution, + PoissonDistribution, + lambda x, y: PoissonDistribution(lam=x.lam + y.lam), + ) + + elif self.fn == operator.mul: + self._join_dists( + NormalDistribution, + Real, + lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * y), + ) + self._join_dists( + LognormalDistribution, + LognormalDistribution, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean + y.norm_mean, + norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2), + ), + ) + self._join_dists( + LognormalDistribution, + Real, + lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y), + ) + self._join_dists( + ExponentialDistribution, + Real, + lambda x, y: ExponentialDistribution(scale=x.scale * y), + ) + self._join_dists( + GammaDistribution, + Real, + lambda x, y: GammaDistribution(shape=x.shape, scale=x.scale * y), + ) + + elif self.fn == operator.truediv: + self._join_dists( + LognormalDistribution, + LognormalDistribution, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean - y.norm_mean, + norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2), + ), + commutative=False, + ) + self._join_dists( + Real, + LognormalDistribution, + lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x), + commutative=False, + ) + + elif self.fn == operator.pow: + self._join_dists( + LognormalDistribution, + Real, + lambda x, y: LognormalDistribution( + norm_mean=x.norm_mean * y, norm_sd=x.norm_sd * y + ), + commutative=False, + condition=lambda x, y: y > 0, + ) + + return reduce( + lambda acc, x: ComplexDistribution(acc, x, fn=self.fn, fn_str=self.fn_str), + self.children, + ) diff --git a/squigglepy/samplers.py b/squigglepy/samplers.py index e88df0b..031c6cc 100644 --- a/squigglepy/samplers.py +++ b/squigglepy/samplers.py @@ -39,6 +39,7 @@ LognormalDistribution, MixtureDistribution, NormalDistribution, + OperableDistribution, ParetoDistribution, PoissonDistribution, TDistribution, @@ -877,6 +878,10 @@ def sample( if verbose is None: verbose = n >= 1000000 + # Simplify distribution analytically before sampling + if isinstance(dist, OperableDistribution): + dist = dist.simplify() + # Handle loading from cache samples = None has_in_mem_cache = str(dist) in _squigglepy_internal_sample_caches diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 1861735..a768601 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -848,11 +848,12 @@ def sample_fn(): assert sample(sample_fn) == 5 +@patch.object(samplers, "gamma_sample", Mock(return_value=1)) @patch.object(samplers, "normal_sample", Mock(return_value=1)) @patch.object(samplers, "lognormal_sample", Mock(return_value=4)) def test_sample_callable_resolves_fully2(): def really_inner_sample_fn(): - return 1 + return gamma(1, 4) def inner_sample_fn(): return norm(1, 4) + lognorm(1, 10) + really_inner_sample_fn() @@ -870,15 +871,16 @@ def test_sample_invalid_input(): @patch.object(samplers, "normal_sample", Mock(return_value=100)) +@patch.object(samplers, "lognormal_sample", Mock(return_value=100)) def test_sample_math(): - assert ~(norm(0, 1) + norm(1, 2)) == 200 + assert ~(norm(0, 1) + lognorm(1, 2)) == 200 @patch.object(samplers, "normal_sample", Mock(return_value=10)) @patch.object(samplers, "lognormal_sample", Mock(return_value=100)) def test_sample_complex_math(): - obj = (2 ** norm(0, 1)) - (8 * 6) + 2 + (lognorm(10, 100) / 11) + 8 - expected = (2**10) - (8 * 6) + 2 + (100 / 11) + 8 + obj = (2 ** norm(0, 1)) - (8 * 6) + 2 + (lognorm(10, 100) + 11) / 8 + expected = (2**10) - (8 * 6) + 2 + (100 + 11) / 8 assert ~obj == expected diff --git a/tests/test_simplify.py b/tests/test_simplify.py new file mode 100644 index 0000000..2aa62ad --- /dev/null +++ b/tests/test_simplify.py @@ -0,0 +1,297 @@ +from hypothesis import example, given +import hypothesis.strategies as st +from numbers import Real +import numpy as np +from pytest import approx + +from ..squigglepy.distributions import ( + bernoulli, + binomial, + exponential, + gamma, + lognorm, + norm, + BinomialDistribution, + ComplexDistribution, + ExponentialDistribution, + GammaDistribution, + LognormalDistribution, + NormalDistribution, +) + + +def test_simplify_add_norm(): + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) + sum2 = x + y + simplified2 = sum2.simplify() + assert isinstance(simplified2, NormalDistribution) + assert simplified2.mean == 3 + assert simplified2.sd == approx(np.sqrt(5)) + + +def test_simplify_add_3_normals(): + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) + z = norm(mean=-3, sd=2) + sum3_left = (x + y) + z + sum3_right = x + (y + z) + simplified3_left = sum3_left.simplify() + simplified3_right = sum3_right.simplify() + assert isinstance(simplified3_left, NormalDistribution) + assert simplified3_left.mean == 0 + assert simplified3_left.sd == approx(np.sqrt(1 + 4 + 4)) + assert isinstance(simplified3_right, NormalDistribution) + assert simplified3_right.mean == 0 + assert simplified3_right.sd == approx(np.sqrt(1 + 4 + 4)) + + +def test_simplify_normal_plus_const(): + x = norm(mean=0, sd=1) + y = 2 + z = norm(mean=1, sd=1) + simplified = (x + y + z).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == 3 + assert simplified.sd == approx(np.sqrt(2)) + + +def simplify_scale_norm(): + x = norm(mean=2, sd=4) + y = 1.5 + product2 = x * y + simplified = product2.simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == approx(3) + assert simplified.sd == approx(6) + + +def test_simplify_mul_3_lognorms(): + x = lognorm(norm_mean=1, norm_sd=1) + y = lognorm(norm_mean=2, norm_sd=2) + z = lognorm(norm_mean=-2, norm_sd=3) + product3_left = (x * y) * z + product3_right = x * (y * z) + simplified3_left = product3_left.simplify() + simplified3_right = product3_right.simplify() + assert isinstance(simplified3_left, LognormalDistribution) + assert simplified3_left.norm_mean == 1 + assert simplified3_left.norm_sd == approx(np.sqrt(1 + 4 + 9)) + assert isinstance(simplified3_right, LognormalDistribution) + assert simplified3_right.norm_mean == 1 + assert simplified3_right.norm_sd == approx(np.sqrt(1 + 4 + 9)) + + +def test_simplify_sub_normals(): + x = norm(mean=1, sd=1) + y = norm(mean=2, sd=2) + difference = x - y + simplified = difference.simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == -1 + assert simplified.sd == approx(np.sqrt(5)) + + +def test_simplify_normal_minus_const(): + x = norm(mean=0, sd=1) + y = 2 + simplified = (x - y).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == -2 + assert simplified.sd == approx(1) + + +@given( + norm_mean=st.floats(min_value=-10, max_value=10), + norm_sd=st.floats(min_value=0.1, max_value=3), + y=st.floats(min_value=0.1, max_value=10), +) +def test_simplify_div_lognorm_by_constant(norm_mean, norm_sd, y): + x = lognorm(norm_mean=norm_mean, norm_sd=norm_sd) + simplified = (x / y).simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(norm_mean - np.log(y)) + assert simplified.norm_sd == approx(norm_sd) + + +@given( + x=st.floats(min_value=0.1, max_value=10), + norm_mean=st.floats(min_value=-10, max_value=10), + norm_sd=st.floats(min_value=0.1, max_value=3), +) +@example(x=1, norm_mean=0, norm_sd=1) +def test_simplify_div_constant_by_lognorm(x, norm_mean, norm_sd): + y = lognorm(norm_mean=norm_mean, norm_sd=norm_sd) + simplified = (x / y).simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(np.log(x) - norm_mean) + assert simplified.norm_sd == approx(norm_sd) + + +def test_simplify_div_lognorms(): + x = lognorm(norm_mean=1, norm_sd=1) + y = lognorm(norm_mean=2, norm_sd=2) + quotient = x / y + simplified = quotient.simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == -1 + assert simplified.norm_sd == approx(np.sqrt(5)) + + +def test_simplify_div_norm_by_const(): + x = norm(mean=3, sd=1) + y = 2 + simplified = (x / y).simplify() + assert isinstance(simplified, NormalDistribution) + assert simplified.mean == approx(1.5) + assert simplified.sd == approx(0.5) + + +def test_simplify_div_norm_by_lognorm(): + x = norm(mean=3, sd=1) + y = lognorm(norm_mean=2, norm_sd=2) + simplified = (x / y).simplify() + assert isinstance(simplified, ComplexDistribution) + assert simplified.fn_str == "*" + assert isinstance(simplified.left, NormalDistribution) + assert simplified.left.mean == 3 + assert simplified.left.sd == 1 + assert isinstance(simplified.right, LognormalDistribution) + assert simplified.right.norm_mean == -2 + assert simplified.right.norm_sd == approx(2) + + +def test_simplify_lognorm_pow(): + x = lognorm(norm_mean=3, norm_sd=2) + y = 2 + product = x**y + simplified = product.simplify() + assert isinstance(simplified, LognormalDistribution) + assert simplified.norm_mean == approx(6) + assert simplified.norm_sd == approx(4) + + y = -1 + product = x**y + simplified = product.simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_clipped_norm(): + x = norm(mean=1, sd=1) + y = norm(mean=0, sd=1, lclip=1) + z = norm(mean=3, sd=2) + simplified = (x + y + z).simplify() + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.left, NormalDistribution) + assert isinstance(simplified.right, NormalDistribution) + assert simplified.left.mean == 4 + assert simplified.left.sd == approx(np.sqrt(5)) + assert simplified.right.mean == 0 + assert simplified.right.sd == 1 + assert simplified.right.lclip == 1 + + +def test_simplify_bernoulli_sum(): + simplified = (bernoulli(p=0.5) + bernoulli(p=0.5)).simplify() + assert isinstance(simplified, BinomialDistribution) + assert simplified.n == 2 + assert simplified.p == 0.5 + + # Cannot simplify if the probabilities are different + simplified = (bernoulli(p=0.5) + bernoulli(p=0.6)).simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_bernoulli_plus_binomial(): + simplified = ( + bernoulli(p=0.5) + + binomial(n=10, p=0.2) + + binomial(n=2, p=0.5) + + binomial(n=3, p=0.5) + + bernoulli(p=0.5) + ).simplify() + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.left, BinomialDistribution) + assert isinstance(simplified.right, BinomialDistribution) + assert simplified.left.n == 7 + assert simplified.left.p == 0.5 + assert simplified.right.n == 10 + assert simplified.right.p == 0.2 + + +def test_simplify_gamma_sum(): + simplified = (gamma(shape=2, scale=1) + gamma(shape=4, scale=1)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 6 + assert simplified.scale == 1 + + # Cannot simplify if the scales are different + simplified = (gamma(shape=2, scale=1) + gamma(shape=2, scale=2)).simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_scale_gamma(): + simplified = (2 * gamma(shape=2, scale=3)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 2 + assert simplified.scale == 6 + + +def test_simplify_exponential_sum(): + simplified = (exponential(scale=2) + exponential(scale=2)).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 2 + assert simplified.scale == 2 + + # Cannot simplify if the rates are different + simplified = (exponential(scale=2) + exponential(scale=3)).simplify() + assert isinstance(simplified, ComplexDistribution) + + +def test_simplify_exponential_gamma_sum(): + simplified = ( + 5 * (exponential(scale=3) + gamma(shape=2, scale=3) + exponential(scale=3)) + ).simplify() + assert isinstance(simplified, GammaDistribution) + assert simplified.shape == 4 + assert simplified.scale == 15 + + +def test_simplify_big_sum(): + simplified = ( + 2 * norm(mean=1, sd=1) + + lognorm(norm_mean=1, norm_sd=1) * lognorm(norm_mean=3, norm_sd=2) + + gamma(shape=2, scale=3) + + exponential(scale=3) / 5 + + exponential(scale=3) + - norm(mean=2, sd=1) + ).simplify() + + # simplifies to norm(0, sqrt(5)) + lognorm(4, sqrt(5)) + gamma(3, 3) + exponential(6) + assert isinstance(simplified, ComplexDistribution) + assert isinstance(simplified.right, ExponentialDistribution) + assert simplified.right.scale == approx(3 / 5) + assert isinstance(simplified.left, ComplexDistribution) + assert isinstance(simplified.left.right, GammaDistribution) + assert simplified.left.right.shape == 3 + assert simplified.left.right.scale == 3 + assert isinstance(simplified.left.left, ComplexDistribution) + assert isinstance(simplified.left.left.right, LognormalDistribution) + assert simplified.left.left.right.norm_mean == 4 + assert simplified.left.left.right.norm_sd == approx(np.sqrt(5)) + assert isinstance(simplified.left.left.left, NormalDistribution) + assert simplified.left.left.left.mean == 0 + assert simplified.left.left.left.sd == approx(np.sqrt(5)) + + +def test_preserve_non_commutative_op(): + simplified = (2**3 ** norm(mean=0, sd=1)).simplify() + assert isinstance(simplified, ComplexDistribution) + assert simplified.fn_str == "**" + assert isinstance(simplified.left, Real) + assert simplified.left == 2 + assert isinstance(simplified.right, ComplexDistribution) + assert simplified.right.fn_str == "**" + assert isinstance(simplified.right.left, Real) + assert simplified.right.left == 3 + assert isinstance(simplified.right.right, NormalDistribution)