From c4ffb390f603f3ef40d3f121662790180bdf1cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 3 Nov 2022 15:17:19 +0100 Subject: [PATCH] Implement log-probability for StudentTRV --- aeppl/logprob.py | 14 ++++++++++++++ setup.py | 2 +- tests/test_logprob.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/aeppl/logprob.py b/aeppl/logprob.py index 8b52949c..28e00dc5 100644 --- a/aeppl/logprob.py +++ b/aeppl/logprob.py @@ -534,6 +534,20 @@ def hypergeometric_logprob(op, values, *inputs, **kwargs): return res +@_logprob.register(arb.StudentTRV) +def studentt_logprob(op, values, *inputs, **kwargs): + (value,) = values + df, loc, scale = inputs[3:] + res = ( + at.gammaln((df + 1.0) / 2.0) + - at.gammaln(df / 2.0) + - 0.5 * at.log(np.pi * df * scale**2) + - (df + 1.0) / 2.0 * at.log1p(((value - loc) / scale) ** 2 / df) + ) + res = CheckParameterValue("scale >= 0")(res, at.all(at.ge(scale, 0.0))) + return res + + @_logprob.register(arb.CategoricalRV) def categorical_logprob(op, values, *inputs, **kwargs): (value,) = values diff --git a/setup.py b/setup.py index abec7a0e..fbfa4ae1 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ install_requires=[ "numpy>=1.18.1", "scipy>=1.4.0", - "aesara >= 2.8.5", + "aesara >= 2.8.8", ], tests_require=["pytest"], long_description=open("README.rst").read() if exists("README.rst") else "", diff --git a/tests/test_logprob.py b/tests/test_logprob.py index 642204f2..1c45fff8 100644 --- a/tests/test_logprob.py +++ b/tests/test_logprob.py @@ -883,6 +883,34 @@ def scipy_logprob(obs, good, bad, n): scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob) +@pytest.mark.parametrize( + "dist_params, obs, size, error", + [ + ((1, 0, 2), np.array([-10, 0, 10], dtype=np.float32), (), False), + ((1, 0, 2), np.array([-10, 0, 10], dtype=np.float32), (3, 2), False), + ( + (np.array([10, 5, 3], dtype=np.int64), 1, 2), + np.array([-1, 1, 84], dtype=np.int64), + (), + False, + ), + ], +) +def test_studentt_logprob(dist_params, obs, size, error): + dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size) + dist_params = dict(zip(dist_params_at, dist_params)) + + x = at.random.t(*dist_params_at, size=size_at) + + cm = contextlib.suppress() if not error else pytest.raises(AssertionError) + + def scipy_logprob(obs, df, loc, scale): + return stats.t.logpdf(obs, df, loc=loc, scale=scale) + + with cm: + scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob) + + @pytest.mark.parametrize( "dist_params, obs, size, exc_type, chk_bcast", [