From aa2e6e430e324c81f4e90405ca54496bfc45cfb8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 22 Mar 2026 02:56:52 +0000 Subject: [PATCH] Implement factorised pointwise probabilities --- Project.toml | 2 + src/accumulators/pointwise_logdensities.jl | 66 ++++++++++++++++------ 2 files changed, 51 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index af9ae267f..747a8041c 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -69,6 +70,7 @@ MacroTools = "0.5.6" MarginalLogDensities = "0.4.3" Mooncake = "0.4.147, 0.5" OrderedCollections = "1" +PosteriorStats = "0.4.8" PrecompileTools = "1.2.1" Printf = "1.10" Random = "1.6" diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index 63e5aa4b9..cd175fe22 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -1,3 +1,5 @@ +import PosteriorStats + """ PointwiseLogProb{Prior,Likelihood} @@ -12,12 +14,21 @@ This struct is used in conjunction with `VNTAccumulator`, via where `Prior` and `Likelihood` are the boolean type parameters. This accumulator will then store the log-probabilities for all tilde-statements in the model. """ -struct PointwiseLogProb{Prior,Likelihood} end -function (plp::PointwiseLogProb{Prior,Likelihood})( +struct PointwiseLogProb{Prior,Likelihood,Factorised} end +function (plp::PointwiseLogProb{Prior,Likelihood,Factorised})( val, tval, logjac, vn, dist -) where {Prior,Likelihood} - if Prior - return logpdf(dist, val) +) where {Prior,Likelihood,Factorised} + return if Prior + if Factorised && hasmethod( + PosteriorStats.pointwise_conditional_loglikelihoods, + Tuple{typeof(val),typeof([dist])}, + ) + dropdims( + PosteriorStats.pointwise_conditional_loglikelihoods(val, [dist]); dims=1 + ) + else + logpdf(dist, val) + end else return DoNotAccumulate() end @@ -32,15 +43,26 @@ end # Have to overload accumulate_assume!! since VNTAccumulator by default does not track # observe statements. function accumulate_observe!!( - acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood}}, + acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood,Factorised}}, right, left, vn, template, -) where {Prior,Likelihood} +) where {Prior,Likelihood,Factorised} # vn could be `nothing`, in which case we can't store it in a VNT. return if Likelihood && vn isa VarName - logp = logpdf(right, left) + logp = + if Factorised && hasmethod( + PosteriorStats.pointwise_conditional_loglikelihoods, + Tuple{typeof(left),typeof([right])}, + ) + dropdims( + PosteriorStats.pointwise_conditional_loglikelihoods(left, [right]); + dims=1, + ) + else + logpdf(right, left) + end new_values = DynamicPPL.templated_setindex!!(acc.values, logp, vn, template) return VNTAccumulator{POINTWISE_ACCNAME}(acc.f, new_values) else @@ -54,7 +76,8 @@ end model::Model, varinfo::AbstractVarInfo, ::Val{Prior}=Val(true), - ::Val{Likelihood}=Val(true), + ::Val{Likelihood}=Val(true); + factorize=false ) where {Prior,Likelihood} Shared internal function that computes pointwise log-densities (either priors, likelihoods, @@ -64,23 +87,32 @@ function _pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{Prior}=Val(true), - ::Val{Likelihood}=Val(true), + ::Val{Likelihood}=Val(true); + factorize=false, ) where {Prior,Likelihood} - acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}()) + acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood,factorize}()) oavi = OnlyAccsVarInfo(acc) init_strategy = InitFromParams(varinfo.values, nothing) oavi = last(init!!(model, oavi, init_strategy, UnlinkAll())) return get_pointwise_logprobs(oavi) end -function pointwise_logdensities(model::Model, varinfo::AbstractVarInfo) - return _pointwise_logdensities(model, varinfo, Val(true), Val(true)) +function pointwise_logdensities(model::Model, varinfo::AbstractVarInfo; factorize=false) + return _pointwise_logdensities( + model, varinfo, Val(true), Val(true); factorize=factorize + ) end -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - return _pointwise_logdensities(model, varinfo, Val(false), Val(true)) +function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo; factorize=false) + return _pointwise_logdensities( + model, varinfo, Val(false), Val(true); factorize=factorize + ) end -function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) - return _pointwise_logdensities(model, varinfo, Val(true), Val(false)) +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo; factorize=false +) + return _pointwise_logdensities( + model, varinfo, Val(true), Val(false); factorize=factorize + ) end