Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# 0.41

Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)`
from the MarginalLogDensities extension. Users should now use
`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy
and pass it to `init!!` to get a consistent `VarInfo`.

Comment on lines +3 to +7
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Please could you rebase against the breaking branch, since this is an API change?
  2. The changelog doesn't need to be formatted to a certain character width; we use semantic linebreaks in the changelog (i.e. each sentence is on its own line). You can see the rest of the changelog for examples.

# 0.40.15

DynamicPPL now allows you to set the type that log-probabilities are initialised with, using the `DynamicPPL.set_logprob_type!` function.
Expand Down
3 changes: 0 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC
using MarginalLogDensities: MarginalLogDensities
using Random

# Need this to document a method which uses a type inside the extension...
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)

# Doctest setup
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ marginalize
```

A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
To retrieve a VarInfo object from it, you can use:
To obtain an initialisation strategy reflecting the state of the marginalisation, you can use:

```@docs
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing})
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am concerned that InitFromVector is not a super clear name and that we may need to rename it later, whereas VarInfor is more stable.

Copy link
Copy Markdown
Member

@penelopeysm penelopeysm Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It already exists, though: https://turinglang.org/DynamicPPL.jl/stable/ldf/models/ and it's not really that complicated to use.

I'm conscious of backwards compatibility, but I really strongly believe that VarInfo is not the right thing to keep around. As you are aware, there have been many perf gains in DynamicPPL recently. But we haven't managed to do that because I'm some hacker Linus Torvalds genius who optimised low level code: they're just because we have better, more modular, interfaces that don't make users pay for things that they don't need.

VarInfo is the exact antithesis of that: it's a struct that bundles lots of stuff together and forces users to pay for things they don't need. Besides the performance issues, it also leads to incorrect results with unflatten!! because log-probs and transforms are not updated in tandem with the parameters, which is an issue I've pointed out many times. That's exactly the reason why I think we should remove this method.

If we want Turing/DynamicPPL to be something that people trust, then correctness is very, very important. I think it's the single most important thing, even more so than performance.

```

## Models within models
Expand Down
150 changes: 34 additions & 116 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
using MarginalLogDensities: MarginalLogDensities

# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
# below.
struct LogDensityFunctionWrapper{
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
}
logdensity::L
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
# actual log-density evaluation.
varinfo::V
end
function (lw::LogDensityFunctionWrapper)(x, _)
return LogDensityProblems.logdensity(lw.logdensity, x)
# Make LogDensityFunction directly callable with the two-argument interface expected by
# MarginalLogDensities. The second argument is the gradient and is unused here because
# MarginalLogDensities handles differentiation separately.
function (ldf::DynamicPPL.LogDensityFunction)(x, _)
return LogDensityProblems.logdensity(ldf, x)
end
Comment on lines +6 to 11
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to not define this function. This makes all LogDensityFunctions callable structs, and it can be very confusing to debug behaviour because if you load this extension it will be callable, and if you don't it won't.


"""
Expand Down Expand Up @@ -53,7 +45,6 @@ log-density.
constructor.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

Expand All @@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0)
marginal log-density can be performed in unconstrained space. However, care must be
taken if the model contains variables where the link transformation depends on a
marginalized variable. For example:

```julia
@model function f()
x ~ Normal()
y ~ truncated(Normal(); lower=x)
end
@model function f()
x ~ Normal()
y ~ truncated(Normal(); lower=x)
end
```

Here, the support of `y`, and hence the link transformation used, depends on the value
Expand All @@ -101,7 +91,7 @@ function DynamicPPL.marginalize(
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
kwargs...,
)
# Construct the marginal log-density model.
# Construct the log-density function directly from the model and varinfo.
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
# Determine the indices for the variables to marginalise out.
varindices = mapreduce(vcat, marginalized_varnames) do vn
Expand All @@ -110,121 +100,49 @@ function DynamicPPL.marginalize(
(ldf._varname_ranges[vn]::RangeAndLinked).range
end
mld = MarginalLogDensities.MarginalLogDensity(
LogDensityFunctionWrapper(ldf, varinfo),
varinfo[:],
varindices,
(),
method;
kwargs...,
ldf, varinfo[:], varindices, (), method; kwargs...
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recognise that the MarginalLogDensity interface requires you to pass a function. Instead of making LogDensityFunction callable, you could instead pass Base.Fix1(LogDensityProblems.logdensity, ldf) as the argument here. That is also a callable, but it accomplishes the same thing without making LogDensityFunction itself callable.

)
return mld
end

"""
VarInfo(
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
InitFromVector(
mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
unmarginalized_params::Union{AbstractVector,Nothing}=nothing
)

Retrieve the `VarInfo` object used in the marginalisation process.

If a Laplace approximation was used for the marginalisation, the values of the marginalized
parameters are also set to their mode (note that this only happens if the `mld` object has
been used to compute the marginal log-density at least once, so that the mode has been
computed).
Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose
parameter vector reflects the state of `mld`.

If a vector of `unmarginalized_params` is specified, the values for the corresponding
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
performing an optimization of the marginal log-density.
If a Laplace approximation was used for marginalisation, the marginalized parameters are set
to their modal values (note that this requires `mld` to have been evaluated at least once,
so that the mode has been found).

All other aspects of the VarInfo, such as link status, are preserved from the original
VarInfo used in the marginalisation.

!!! note

The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
model with the values inside the returned VarInfo, for example using:

```julia
init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
oavi = DynamicPPL.OnlyAccsVarInfo((
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.RawValueAccumulator(false),
# ... whatever else you need
))
_, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
```

You can then extract all the updated data from `oavi`.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
x ~ Normal()
y ~ Beta(2, 2)
end
demo (generic function with 2 methods)
If `unmarginalized_params` is provided, those values are used for the non-marginalized
parameters. This vector may be obtained e.g. by optimizing the marginal log-density.

julia> # Note that by default `marginalize` uses a linked VarInfo.
mld = marginalize(demo(), [@varname(x)]);

julia> using MarginalLogDensities: Optimization, OptimizationOptimJL

julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
OptimizationProblem. In-place: true
u0: 1-element Vector{Float64}:
2.0

julia> # This tells us the optimal (linked) value of `y` is around 0.
opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
retcode: Success
u: 1-element Vector{Float64}:
4.88281250001733e-5

julia> # Get the VarInfo corresponding to the mode of `y`.
vi = VarInfo(mld, opt_solution.u);

julia> # `x` is set to its mode (which for `Normal()` is zero).
vi[@varname(x)]
0.0

julia> # `y` is set to the optimal value we found above.
DynamicPPL.getindex_internal(vi, @varname(y))
1-element Vector{Float64}:
4.88281250001733e-5

julia> # To obtain values in the original constrained space, we can either
# use `getindex`:
vi[@varname(y)]
0.5000122070312476

julia> # Or invlink the entire VarInfo object using the model:
vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
2-element Vector{Float64}:
0.0
0.5000122070312476
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
status — use the returned strategy to re-evaluate the model:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
Comment on lines +124 to +129
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
status — use the returned strategy to re-evaluate the model:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
The returned `InitFromVector` strategy can then be used to re-evaluate the model (see also
[the DynamicPPL docs](@ref ldf-model). For example, if `opt_solution` is a vector of
unmarginalised parameters obtained from optimisation of the `mld` object, then you can
write:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
accs = DynamicPPL.OnlyAccsVarInfo()
_, accs = DynamicPPL.init!!(ldf.model, accs, init_strategy, ldf.transform_strategy)

```
"""
function DynamicPPL.VarInfo(
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
function DynamicPPL.InitFromVector(
mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
# Extract the original VarInfo. Its contents will in general be junk.
original_vi = mld.logdensity.varinfo
# Extract the stored parameters, which includes the modes for any marginalized
# parameters
# Retrieve the full cached parameter vector (includes modal values for marginalized
# parameters if a Laplace approximation has been run).
full_params = MarginalLogDensities.cached_params(mld)
# We can then (if needed) set the values for any non-marginalized parameters
# Overwrite the non-marginalized entries if the caller supplied them.
if unmarginalized_params !== nothing
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
end
return DynamicPPL.unflatten!!(original_vi, full_params)
# Use the convenience constructor that reads varname_ranges and transform_strategy
# directly from the LogDensityFunction stored inside mld.
return DynamicPPL.InitFromVector(full_params, mld.logdensity)
end

end
24 changes: 19 additions & 5 deletions test/ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,34 @@ using ADTypes: AutoForwardDiff

@testset "unlinked VarInfo" begin
mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
vi = VarInfo(mx)
mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values
strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
ldf = mx.logdensity
_, vi = DynamicPPL.init!!(
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] ≈ mode(Normal())
Comment on lines +85 to 88
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_, vi = DynamicPPL.init!!(
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] mode(Normal())
accs = OnlyAccsVarInfo(RawValueAccumulator(false))
_, accs = DynamicPPL.init!!(
ldf.model, accs, strategy, ldf.transform_strategy
)
@test get_raw_values(accs)[@varname(x)] mode(Normal())

And likewise for the other tests too (see e.g. https://turinglang.org/DynamicPPL.jl/stable/migration/#Getting-parameter-values).

I think it might also be better to just extract this into a separate function (the function could return, e.g., get_raw_values(accs) and then in the tests you can check the values in it). There's quite a lot of code duplication in these tests.

vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked
strategy = DynamicPPL.InitFromVector(mx, [0.5]) # same, but override the unmarginalized parameter with 0.5
_, vi = DynamicPPL.init!!(
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] ≈ mode(Normal())
@test vi[@varname(y)] ≈ 0.5
end

@testset "linked VarInfo" begin
mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
vi = VarInfo(mx)
strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
ldf = mx.logdensity
_, vi = DynamicPPL.init!!(
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] ≈ mode(Normal())
vi = VarInfo(mx, [0.5]) # this 0.5 is linked
strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y
_, vi = DynamicPPL.init!!(
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
@test vi[@varname(x)] ≈ mode(Normal())
# when using getindex it always returns unlinked values
Expand Down
Loading