Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 176 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from aesara.tensor.random.basic import (
BernoulliRV,
BinomialRV,
NegBinomialRV,
PoissonRV,
UniformRV,
)
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -268,13 +274,182 @@
return rv_var.owner.outputs


def beta_bernoulli_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a negative binomial observation model.

.. math::

\frac{
Y \sim \operatorname{P(x=1)}= p, \quad
p \sim \operatorname{Beta}\left(\alpha, \beta\right)
}{
\left(p \mid Y=y\right) \sim \operatorname{Beta}\left(\alpha + \sum^{n}_{i=1} y_i, \beta + n - \sum^{n}_{i=1} y_i,\right)
}


Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.

"""
# beta-negative_binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
p_size_lv = var()
p_type_idx_lv = var()
p_et = etuple(
etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv
)
Y_et = etuple(etuplize(at.random.bernoulli), var(), var(), var(), p_et)

new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(etuplize(at.add), beta_lv, 1, -observed_val)

p_posterior_et = etuple(
etuplize(at.random.beta),
new_alpha_et,
new_beta_et,
rng=p_rng_lv,
size=p_size_lv,
dtype=p_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, p_posterior_et),
)


@node_rewriter([BernoulliRV])
def local_beta_bernoulli_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

Check warning on line 332 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L332

Added line #L332 was not covered by tests

rv_var = node.outputs[1]
key = ("local_beta_bernoulli_posterior", rv_var)

Check warning on line 335 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L334-L335

Added lines #L334 - L335 were not covered by tests

if sampler_mappings is None or key in sampler_mappings.rvs_seen:

Check warning on line 337 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L337

Added line #L337 was not covered by tests
return None # pragma: no cover

q = var()

Check warning on line 340 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L340

Added line #L340 was not covered by tests

rv_et = etuplize(rv_var)

Check warning on line 342 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L342

Added line #L342 was not covered by tests

res = run(None, q, beta_bernoulli_conjugateo(rv_var, rv_et, q))
res = next(res, None)

Check warning on line 345 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L344-L345

Added lines #L344 - L345 were not covered by tests

if res is None:

Check warning on line 347 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L347

Added line #L347 was not covered by tests
return None # pragma: no cover

beta_rv = rv_et[-1].evaled_obj
beta_posterior = eval_if_etuple(res)

Check warning on line 351 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L350-L351

Added lines #L350 - L351 were not covered by tests

sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append(

Check warning on line 353 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L353

Added line #L353 was not covered by tests
("local_beta_bernoulli_posterior", beta_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

Check warning on line 356 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L356

Added line #L356 was not covered by tests

return rv_var.owner.outputs

Check warning on line 358 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L358

Added line #L358 was not covered by tests


def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a pareto prior with a uniform with 0 as the lower bound observation model.

.. math::
Y \sim \operatorname{Uniform}\left(0, \theta\right)
\theta \sim \operatorname{pareto}\(max(x), k)



Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.

"""
# beta-negative_binomial observation model
x_lv, k_lv = var(), var()
theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
k_lv,
x_lv,
)
Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et)

new_x_et = etuple(at.math.max, observed_val)
new_k_et = etuple(etuplize(at.add), k_lv, 1)

theta_posterior_et = etuple(
etuplize(at.random.pareto),
new_k_et,
new_x_et,
rng=theta_rng_lv,
size=theta_size_lv,
dtype=theta_type_idx_lv,
)
return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, theta_posterior_et),
)


@node_rewriter([UniformRV])
def local_uniform_pareto_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

Check warning on line 416 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L416

Added line #L416 was not covered by tests

rv_var = node.outputs[1]
key = ("local_beta_negative_binomial_posterior", rv_var)

Check warning on line 419 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L418-L419

Added lines #L418 - L419 were not covered by tests

if sampler_mappings is None or key in sampler_mappings.rvs_seen:

Check warning on line 421 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L421

Added line #L421 was not covered by tests
return None # pragma: no cover

q = var()

Check warning on line 424 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L424

Added line #L424 was not covered by tests

rv_et = etuplize(rv_var)

Check warning on line 426 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L426

Added line #L426 was not covered by tests

res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
res = next(res, None)

Check warning on line 429 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L428-L429

Added lines #L428 - L429 were not covered by tests

if res is None:

Check warning on line 431 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L431

Added line #L431 was not covered by tests
return None # pragma: no cover

pareto_rv = rv_et[-1].evaled_obj
pareto_posterior = eval_if_etuple(res)

Check warning on line 435 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L434-L435

Added lines #L434 - L435 were not covered by tests

sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(

Check warning on line 437 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L437

Added line #L437 was not covered by tests
("local_uniform_pareto_posterior", pareto_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

Check warning on line 440 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L440

Added line #L440 was not covered by tests

return rv_var.owner.outputs

Check warning on line 442 in aemcmc/conjugates.py

View check run for this annotation

Codecov / codecov/patch

aemcmc/conjugates.py#L442

Added line #L442 was not covered by tests


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")


sampler_finder_db.register(
Expand Down
122 changes: 122 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import pytest
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random import RandomStream
from etuples import etuple, etuplize
from kanren import run
from unification import var

from aemcmc.conjugates import (
beta_bernoulli_conjugateo,
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
uniform_pareto_conjugateo,
)


Expand Down Expand Up @@ -157,3 +160,122 @@ def test_beta_negative_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_beta_bernoulli_conjugate_contract():
"""Produce the closed-form posterior for the binomial observation model with
a beta prior.

"""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
p_rv = srng.beta(alpha_tt, beta_tt, name="p")

Y_rv = srng.bernoulli(p_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, beta_bernoulli_conjugateo(y_vv, Y_rv, q_lv))
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.beta))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((alpha_tt, beta_tt, y_vv), posterior)
assert sample_fn(1.0, 1.0, 1) == pytest.approx(1.0, abs=0.3) # only successes
assert sample_fn(1.0, 1.0, 0) == pytest.approx(0.0, abs=0.3) # no success


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_beta_bernoulli_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
n_tt = at.iscalar("n")
Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt - y_vv)

e_lv = var()
(expanded_expr,) = run(1, e_lv, beta_bernoulli_conjugateo(e_lv, y_vv, Y_rv))
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_uniform_pareto_conjugate_contract():
"""Produce the closed-form posterior for the uniform observation model with
a pareto prior.

"""
srng = RandomStream(0)

xm_tt = at.scalar("xm")
k_tt = at.scalar("k")
theta_rv = srng.pareto(k_tt, xm_tt, name="theta")

# zero = at.iscalar("zero")
Y_rv = srng.uniform(0, theta_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.pareto))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior)
assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01) # k = 1000
assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01) # all zeros


def test_uniform_pareto_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the test graph looks like this:

>>> aesara.dprint(Y_rv)
pareto_rv{0, (0, 0), floatX, False}.1 [id A]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FC4162BFF20>) [id B]
 |TensorConstant{[]} [id C]
 |TensorConstant{11} [id D]
 |MaxAndArgmax{axis=()}.0 [id E] 'max'
 | |y [id F]
 |Elemwise{add,no_inplace} [id G]
   |k [id H]
   |n [id I]

and that MaxAndArgmax Op isn't same as the at.math.max used in the etuple graph. at.math.max is a function that constructs a MaxAndArgmax Op and uses it to further construct a graph for the max of its argument. In other words, we need an etuple form/"pattern" that matches the types of graphs output by the helper function at.math.max.

Often the easiest way to find etuple forms for the graphs constructed by helper functions is to etuplize said graphs and spot their generalities.
For example:

>>> from etuples import etuplize
>>> etuplize(at.math.max(at.vector("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0,)), x)
>>> etuplize(at.math.max(at.matrix("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0, 1)), x)

As we can see, the axis property in the MaxAndArgmax Op will change according to the dimensions of the input (i.e. it computes the max across all dimensions), so we don't want to use a very specific value for the matching form. Instead, we can use another logic variable in place of those values.

Here's a general testing setup for that part of the problem:

import aesara
import aesara.tensor as at

from etuples import etuplize, etuple


srng = at.random.RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)

# This is what we need to match/unify:
etuplize(Y_rv)
# e(
#     e(aesara.tensor.random.basic.ParetoRV, 'pareto', 0, (0, 0), 'floatX', False),
#     RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FA1F9B3D9E0>),
#     TensorConstant{[]},
#     TensorConstant{11},
#     e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#     e(
#         e(
#             aesara.tensor.elemwise.Elemwise,
#             <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#             <frozendict {}>),
#         k,
#         n))

from unification import var
from kanren import run, eq
from aesara.tensor.math import MaxAndArgmax


observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
    etuplize(at.random.pareto),
    theta_rng_lv,
    theta_size_lv,
    theta_type_idx_lv,
    new_x_et,
    new_k_et,
)


run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))
# ((e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#   e(
#       e(
#           aesara.tensor.elemwise.Elemwise,
#           <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#           <frozendict {}>),
#       k,
#       n)),)

etuplize(Y_rv)

# e_lv = var()
# (expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv))
# expanded = eval_if_etuple(expanded_expr)

# assert isinstance(expanded.owner.op, type(at.random.pareto))
from aesara.tensor.math import MaxAndArgmax
from kanren import eq, run
from unification import var

observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
new_x_et,
new_k_et,
)

run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))