Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -69,6 +70,7 @@ MacroTools = "0.5.6"
MarginalLogDensities = "0.4.3"
Mooncake = "0.4.147, 0.5"
OrderedCollections = "1"
PosteriorStats = "0.4.8"
PrecompileTools = "1.2.1"
Printf = "1.10"
Random = "1.6"
Expand Down
66 changes: 49 additions & 17 deletions src/accumulators/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import PosteriorStats

"""
PointwiseLogProb{Prior,Likelihood}

Expand All @@ -12,12 +14,21 @@ This struct is used in conjunction with `VNTAccumulator`, via
where `Prior` and `Likelihood` are the boolean type parameters. This accumulator will then
store the log-probabilities for all tilde-statements in the model.
"""
struct PointwiseLogProb{Prior,Likelihood} end
function (plp::PointwiseLogProb{Prior,Likelihood})(
struct PointwiseLogProb{Prior,Likelihood,Factorised} end
function (plp::PointwiseLogProb{Prior,Likelihood,Factorised})(
val, tval, logjac, vn, dist
) where {Prior,Likelihood}
if Prior
return logpdf(dist, val)
) where {Prior,Likelihood,Factorised}
return if Prior
if Factorised && hasmethod(
PosteriorStats.pointwise_conditional_loglikelihoods,
Tuple{typeof(val),typeof([dist])},
)
dropdims(
PosteriorStats.pointwise_conditional_loglikelihoods(val, [dist]); dims=1
)
else
logpdf(dist, val)
end
else
return DoNotAccumulate()
end
Expand All @@ -32,15 +43,26 @@ end
# Have to overload accumulate_assume!! since VNTAccumulator by default does not track
# observe statements.
function accumulate_observe!!(
acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood}},
acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood,Factorised}},
right,
left,
vn,
template,
) where {Prior,Likelihood}
) where {Prior,Likelihood,Factorised}
# vn could be `nothing`, in which case we can't store it in a VNT.
return if Likelihood && vn isa VarName
logp = logpdf(right, left)
logp =
if Factorised && hasmethod(
PosteriorStats.pointwise_conditional_loglikelihoods,
Tuple{typeof(left),typeof([right])},
)
dropdims(
PosteriorStats.pointwise_conditional_loglikelihoods(left, [right]);
dims=1,
)
else
logpdf(right, left)
end
new_values = DynamicPPL.templated_setindex!!(acc.values, logp, vn, template)
return VNTAccumulator{POINTWISE_ACCNAME}(acc.f, new_values)
else
Expand All @@ -54,7 +76,8 @@ end
model::Model,
varinfo::AbstractVarInfo,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
::Val{Likelihood}=Val(true);
factorize=false
) where {Prior,Likelihood}

Shared internal function that computes pointwise log-densities (either priors, likelihoods,
Expand All @@ -64,23 +87,32 @@ function _pointwise_logdensities(
model::Model,
varinfo::AbstractVarInfo,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
::Val{Likelihood}=Val(true);
factorize=false,
) where {Prior,Likelihood}
acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}())
acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood,factorize}())
oavi = OnlyAccsVarInfo(acc)
init_strategy = InitFromParams(varinfo.values, nothing)
oavi = last(init!!(model, oavi, init_strategy, UnlinkAll()))
return get_pointwise_logprobs(oavi)
end

function pointwise_logdensities(model::Model, varinfo::AbstractVarInfo)
return _pointwise_logdensities(model, varinfo, Val(true), Val(true))
function pointwise_logdensities(model::Model, varinfo::AbstractVarInfo; factorize=false)
return _pointwise_logdensities(
model, varinfo, Val(true), Val(true); factorize=factorize
)
end

function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
return _pointwise_logdensities(model, varinfo, Val(false), Val(true))
function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo; factorize=false)
return _pointwise_logdensities(
model, varinfo, Val(false), Val(true); factorize=factorize
)
end

function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo)
return _pointwise_logdensities(model, varinfo, Val(true), Val(false))
function pointwise_prior_logdensities(
model::Model, varinfo::AbstractVarInfo; factorize=false
)
return _pointwise_logdensities(
model, varinfo, Val(true), Val(false); factorize=factorize
)
end
Loading