Skip to content
360 changes: 360 additions & 0 deletions squigglepy/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from collections.abc import Iterable

from abc import ABC, abstractmethod
from functools import reduce
from numbers import Real


class BaseDistribution(ABC):
Expand Down Expand Up @@ -172,6 +174,43 @@ def __rpow__(self, dist):
def __hash__(self):
    """Return a hash derived from this object's ``repr`` string."""
    representation = repr(self)
    return hash(representation)

def simplify(self):
    """Simplify a distribution by evaluating all operations that can be
    performed analytically, for example by reducing a sum of normal
    distributions into a single normal distribution.

    This base implementation performs no simplification and returns
    ``self`` unchanged; subclasses that represent compound expressions
    override it to return a simplified distribution.

    Possible simplifications:
        any - any = any + (-any)
        any / constant = any * (1 / constant)

        -Normal = Normal

        Normal + Normal = Normal
        constant + Normal = Normal
        Bernoulli + Bernoulli = Binomial
        Bernoulli + Binomial = Binomial
        Binomial + Binomial = Binomial
        Chi-Square + Chi-Square = Chi-Square
        Exponential + Exponential = Gamma
        Exponential + Gamma = Gamma
        Gamma + Gamma = Gamma
        Poisson + Poisson = Poisson

        constant * Normal = Normal
        Lognormal * Lognormal = Lognormal
        constant * Lognormal = Lognormal
        constant * Exponential = Exponential
        constant * Gamma = Gamma

        Lognormal / Lognormal = Lognormal
        constant / Lognormal = Lognormal

        Lognormal ** constant = Lognormal

    Sums of Bernoulli/Binomial distributions are only combined when they
    share the same ``p``, and sums of Exponential/Gamma distributions only
    when they share the same ``scale``.

    Returns
    -------
    The simplified distribution (here, ``self``).

    """
    return self


# Distributions are either discrete, continuous, or composite

Expand Down Expand Up @@ -242,6 +281,9 @@ def __str__(self):
raise ValueError
return out

def simplify(self):
    """Return an analytically simplified form of this distribution.

    The expression is first flattened into a ``FlatTree`` and that tree is
    then asked to simplify itself back into a distribution.
    """
    flat_tree = FlatTree.build(self)
    return flat_tree.simplify()


def _get_fname(f, name):
if name is None:
Expand Down Expand Up @@ -1699,3 +1741,321 @@ def geometric(p):
<Distribution> geometric(0.1)
"""
return GeometricDistribution(p=p)


class FlatTree:
    """Helper class for simplifying analytic expressions. A ``FlatTree`` is
    sort of like a ``ComplexDistribution`` except that it flattens
    commutative/associative operations onto a single object instead of having
    one object per binary operation.

    This class operates in two phases.

    Phase 1: Generate a ``FlatTree`` object from a :ref:``BaseDistribution`` by
    calling ``FlatTree.build(dist)``. This generates a tree where any series of
    a single commutative/associative operation done repeatedly is flattened
    onto a single ``FlatTree`` node. It also converts operations into a
    normalized form, for example converting ``a - b`` into ``a + (-b)``.

    Phase 2: Generate a simplified ``Distribution`` by calling
    :ref:``simplify``. This works by combing through each flat list of
    distributions to find which ones can be analytically simplified (for
    example, converting a sum of normal distributions into a single normal
    distribution).

    A node is either a leaf (``dist`` is set) or an operation node (``fn`` and
    ``children`` are set). The children of an operation node are a mix of
    plain numbers, distributions, and nested ``FlatTree`` nodes.

    """

    # Operations whose nested applications may be flattened into one node.
    COMMUTABLE_OPERATIONS = {operator.add, operator.mul}

    def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False, infix=None):
        """Create a leaf node (``dist`` given) or an operation node (``fn``
        and ``children`` given).

        Raises
        ------
        ValueError
            If neither ``dist`` nor both of ``fn`` and ``children`` are given.
        """
        self.dist = dist
        self.fn = fn
        self.fn_str = fn_str
        self.children = children
        self.is_unary = is_unary
        self.infix = infix
        if dist is not None:
            self.is_leaf = True
        elif fn is not None and children is not None:
            self.is_leaf = False
        else:
            raise ValueError("Missing arguments to FlatTree constructor")

    def __str__(self):
        if self.is_leaf:
            return f"FlatTree({self.dist})"
        return "FlatTree({})[{}]".format(self.fn_str, ", ".join(map(str, self.children)))

    def __repr__(self):
        return str(self)

    @staticmethod
    def _is_clipped(x):
        """Return True if ``x`` is a distribution with a left or right clip."""
        return isinstance(x, BaseDistribution) and (x.lclip is not None or x.rclip is not None)

    @classmethod
    def build(cls, dist):
        """Build a ``FlatTree`` from a number or distribution.

        Parameters
        ----------
        dist : BaseDistribution, Real or None
            The expression to convert.

        Returns
        -------
        FlatTree or None
            ``None`` if ``dist`` is ``None``; a leaf for numbers and for
            distributions that are not ``ComplexDistribution``; otherwise an
            operation node.

        Raises
        ------
        ValueError
            If ``dist`` is neither a number nor a ``BaseDistribution``.
        ZeroDivisionError
            If ``dist`` divides by the constant zero.
        """
        if dist is None:
            return None
        if isinstance(dist, Real):
            return cls(dist=dist)
        if not isinstance(dist, BaseDistribution):
            raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}")
        if not isinstance(dist, ComplexDistribution):
            return cls(dist=dist)

        is_unary = dist.right is None

        # Convert x - y into x + (-y)
        if dist.fn == operator.sub:
            return cls.build(
                ComplexDistribution(
                    dist.left,
                    ComplexDistribution(dist.right, right=None, fn=operator.neg, fn_str="-"),
                    fn=operator.add,
                    fn_str="+",
                )
            )

        # If the denominator is a constant, replace division by constant
        # with multiplication by the reciprocal of the constant
        if dist.fn == operator.truediv and isinstance(dist.right, Real):
            if dist.right == 0:
                raise ZeroDivisionError(f"Division by zero in ComplexDistribution: {dist}")
            return cls.build(
                ComplexDistribution(
                    dist.left,
                    1 / dist.right,
                    fn=operator.mul,
                    fn_str="*",
                )
            )

        # The reciprocal of a lognormal is a lognormal with negated norm_mean.
        # A clipped denominator is left alone: rebuilding it from norm_mean
        # and norm_sd would silently drop the clip.
        if (
            dist.fn == operator.truediv
            and isinstance(dist.right, LognormalDistribution)
            and not cls._is_clipped(dist.right)
        ):
            return cls.build(
                ComplexDistribution(
                    dist.left,
                    LognormalDistribution(
                        norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd
                    ),
                    fn=operator.mul,
                    fn_str="*",
                )
            )

        left_tree = cls.build(dist.left)
        right_tree = cls.build(dist.right)

        # If a child node uses the same commutable operation as ``dist``,
        # splice its flattened ``children`` into ours. Leaves contribute their
        # raw value; any other subtree is kept whole.
        can_flatten = dist.fn in cls.COMMUTABLE_OPERATIONS
        children = []
        for subtree in (left_tree, right_tree):
            if subtree is None:
                # Only the right operand of a unary operation is None
                continue
            if subtree.is_leaf:
                children.append(subtree.dist)
            elif can_flatten and subtree.fn == dist.fn:
                children.extend(subtree.children)
            else:
                children.append(subtree)

        return cls(
            fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix
        )

    def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None):
        """Fold joinable children into a single accumulated value, in place.

        Walks ``self.children`` once. The first child of ``left_type`` (or,
        if ``commutative``, of ``right_type``) becomes the accumulator; each
        later child of the complementary type is merged into it with
        ``join_fn(left, right)``. Children that cannot be merged keep their
        relative order, and the accumulator is put back at the position where
        it was first found.

        Parameters
        ----------
        left_type, right_type : type
            Operand types accepted on each side of ``join_fn``.
        join_fn : callable
            ``join_fn(left, right)`` returning the merged value.
        commutative : bool
            Whether a ``right_type`` operand may also appear before the
            ``left_type`` operand.
        condition : callable or None
            Optional ``condition(left, right)`` that must be true to merge.
        """

        def may_join(left, right):
            return condition is None or condition(left, right)

        remaining = []
        acc = None
        acc_index = None
        acc_is_left = True
        for i, x in enumerate(self.children):
            if self._is_clipped(x):
                # We can't simplify a clipped distribution
                remaining.append(x)
            elif acc is None and isinstance(x, left_type):
                acc, acc_index = x, i
            elif (
                acc is not None
                and acc_is_left
                # join_fn may have returned a different type (e.g. Bernoulli +
                # Bernoulli -> Binomial); stop folding once it has, because
                # join_fn is only valid for a genuine left_type operand.
                and isinstance(acc, left_type)
                and isinstance(x, right_type)
                and may_join(acc, x)
            ):
                acc = join_fn(acc, x)
            elif commutative and acc is None and isinstance(x, right_type):
                acc, acc_index, acc_is_left = x, i, False
            elif (
                commutative
                and acc is not None
                and not acc_is_left
                # Same guard as above, for an accumulator on the right side.
                and isinstance(acc, right_type)
                and isinstance(x, left_type)
                and may_join(x, acc)
            ):
                acc = join_fn(x, acc)
            else:
                remaining.append(x)

        if acc is not None:
            # Everything before acc_index was left untouched, so this index
            # is still the accumulator's original position.
            remaining.insert(acc_index, acc)
        self.children = remaining

    @classmethod
    def _lognormal_times_const(cls, norm_mean, norm_sd, k):
        """Return ``k * Lognormal(norm_mean, norm_sd)``.

        The result is ``0`` for ``k == 0``, a lognormal for ``k > 0``, and the
        negation of a lognormal for ``k < 0``.
        """
        if k == 0:
            return 0
        if k > 0:
            return LognormalDistribution(norm_mean=norm_mean + np.log(k), norm_sd=norm_sd)
        return -LognormalDistribution(norm_mean=norm_mean + np.log(-k), norm_sd=norm_sd)

    def simplify(self):
        """Convert a FlatTree back into a Distribution, simplifying as much as
        possible.

        Nested ``FlatTree`` children are simplified first and replaced in
        ``self.children``, so calling this mutates the tree.
        """
        if self.is_leaf:
            return self.dist

        self.children = [
            child.simplify() if isinstance(child, FlatTree) else child for child in self.children
        ]

        # Simplify unary operations
        if len(self.children) == 1:
            child = self.children[0]
            if self.fn == operator.neg:
                if isinstance(child, Real):
                    return -child
                # Negating a clipped normal would lose the clip, so only the
                # unclipped case is rewritten.
                if isinstance(child, NormalDistribution) and not self._is_clipped(child):
                    return NormalDistribution(mean=-child.mean, sd=child.sd)

            return ComplexDistribution(
                child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix
            )

        def same_p(x, y):
            return x.p == y.p

        def same_scale(x, y):
            return x.scale == y.scale

        def positive_constant(x, y):
            return y > 0

        if self.fn == operator.add:
            self._join_dists(
                NormalDistribution,
                NormalDistribution,
                lambda x, y: NormalDistribution(
                    mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2)
                ),
            )
            self._join_dists(
                NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)
            )
            # Bernoulli + Bernoulli merges a single pair; further Bernoullis
            # are absorbed by the Binomial + Bernoulli rule that follows.
            self._join_dists(
                BernoulliDistribution,
                BernoulliDistribution,
                lambda x, y: BinomialDistribution(n=2, p=x.p),
                condition=same_p,
            )
            self._join_dists(
                BinomialDistribution,
                BernoulliDistribution,
                lambda x, y: BinomialDistribution(n=x.n + 1, p=x.p),
                condition=same_p,
            )
            self._join_dists(
                BinomialDistribution,
                BinomialDistribution,
                lambda x, y: BinomialDistribution(n=x.n + y.n, p=x.p),
                condition=same_p,
            )
            self._join_dists(
                ChiSquareDistribution,
                ChiSquareDistribution,
                lambda x, y: ChiSquareDistribution(df=x.df + y.df),
            )
            # Exponential + Exponential merges a single pair; further
            # Exponentials are absorbed by the Exponential + Gamma rule.
            self._join_dists(
                ExponentialDistribution,
                ExponentialDistribution,
                lambda x, y: GammaDistribution(shape=2, scale=x.scale),
                condition=same_scale,
            )
            self._join_dists(
                ExponentialDistribution,
                GammaDistribution,
                lambda x, y: GammaDistribution(shape=y.shape + 1, scale=x.scale),
                condition=same_scale,
            )
            self._join_dists(
                GammaDistribution,
                GammaDistribution,
                lambda x, y: GammaDistribution(shape=x.shape + y.shape, scale=x.scale),
                condition=same_scale,
            )
            self._join_dists(
                PoissonDistribution,
                PoissonDistribution,
                lambda x, y: PoissonDistribution(lam=x.lam + y.lam),
            )

        elif self.fn == operator.mul:
            self._join_dists(
                NormalDistribution,
                Real,
                # Scaling by a negative constant flips the mean's sign but the
                # standard deviation must stay non-negative.
                lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * abs(y)),
            )
            self._join_dists(
                LognormalDistribution,
                LognormalDistribution,
                lambda x, y: LognormalDistribution(
                    norm_mean=x.norm_mean + y.norm_mean,
                    norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2),
                ),
            )
            self._join_dists(
                LognormalDistribution,
                Real,
                lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y),
            )
            # An exponential/gamma scaled by a non-positive constant is not an
            # exponential/gamma, so only positive constants are folded in.
            self._join_dists(
                ExponentialDistribution,
                Real,
                lambda x, y: ExponentialDistribution(scale=x.scale * y),
                condition=positive_constant,
            )
            self._join_dists(
                GammaDistribution,
                Real,
                lambda x, y: GammaDistribution(shape=x.shape, scale=x.scale * y),
                condition=positive_constant,
            )

        elif self.fn == operator.truediv:
            self._join_dists(
                LognormalDistribution,
                LognormalDistribution,
                lambda x, y: LognormalDistribution(
                    norm_mean=x.norm_mean - y.norm_mean,
                    norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2),
                ),
                commutative=False,
            )
            self._join_dists(
                Real,
                LognormalDistribution,
                lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x),
                commutative=False,
            )

        elif self.fn == operator.pow:
            self._join_dists(
                LognormalDistribution,
                Real,
                lambda x, y: LognormalDistribution(
                    norm_mean=x.norm_mean * y, norm_sd=x.norm_sd * y
                ),
                commutative=False,
                condition=positive_constant,
            )

        # Rebuild a left-nested chain of binary operations from whatever is
        # left. A single remaining child is returned as-is.
        return reduce(
            lambda acc, x: ComplexDistribution(
                acc, x, fn=self.fn, fn_str=self.fn_str, infix=self.infix
            ),
            self.children,
        )
Loading