From fcca4c7d5d0cb71cac454cf423cbdec27607bce4 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 23 Mar 2026 23:59:11 +0530 Subject: [PATCH 1/3] Remove LogDensityFunctionWrapper and replace VarInfo(mld, ...) with InitFromVector --- HISTORY.md | 7 + docs/make.jl | 3 - docs/src/accs/threadsafe.md | 2 +- docs/src/api.md | 4 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 150 ++++-------------- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 24 ++- 6 files changed, 63 insertions(+), 127 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 17c328988..9c218947f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,10 @@ +# 0.41 + +Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` +from the MarginalLogDensities extension. Users should now use +`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy +and pass it to `init!!` to get a consistent `VarInfo`. + # 0.40.14 Fixed `check_model()` erroneously failing for models such as `x[1:2] .~ univariate_dist`. diff --git a/docs/make.jl b/docs/make.jl index 389cd7921..98258af75 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC using MarginalLogDensities: MarginalLogDensities using Random -# Need this to document a method which uses a type inside the extension... -DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt) - # Doctest setup DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true diff --git a/docs/src/accs/threadsafe.md b/docs/src/accs/threadsafe.md index a01f5dbc4..788cd1985 100644 --- a/docs/src/accs/threadsafe.md +++ b/docs/src/accs/threadsafe.md @@ -17,7 +17,7 @@ model = setthreadsafe(g(y), true) This is accomplished by creating one copy of each accumulator per thread (using `DynamicPPL.split`), and then after the model evaluation is complete, merging the result of each thread's accumulator with `DynamicPPL.combine`. -**This means that if you are implementing your own accumulator, you will need to implement the `split` and `combine` methods for it in order for it work correctly in thread-safe mode.** +**This means that if you are implementing your own accumulator, you will need to implement the `split` and `combine` methods for it in order for it to work correctly in thread-safe mode.** Each accumulator sees only the tilde-statements that were executed on its own thread. However, the intent is that after merging the results from all threads, the final accumulator should be equivalent to what would have been obtained by a single-threaded evaluation (modulo ordering). diff --git a/docs/src/api.md b/docs/src/api.md index 77066b89c..5d1b54b28 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -169,10 +169,10 @@ marginalize ``` A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability. -To retrieve a VarInfo object from it, you can use: +To obtain an initialisation strategy reflecting the state of the marginalisation, you can use: ```@docs -VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing}) +InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing}) ``` ## Models within models diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 348152d30..b2f233c6a 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities -# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by -# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type -# below. -struct LogDensityFunctionWrapper{ - L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo -} - logdensity::L - # This field is used only to reconstruct the VarInfo later on; it's not needed for the - # actual log-density evaluation. - varinfo::V -end -function (lw::LogDensityFunctionWrapper)(x, _) - return LogDensityProblems.logdensity(lw.logdensity, x) +# Make LogDensityFunction directly callable with the two-argument interface expected by +# MarginalLogDensities. The second argument is the gradient and is unused here because +# MarginalLogDensities handles differentiation separately. +function (ldf::DynamicPPL.LogDensityFunction)(x, _) + return LogDensityProblems.logdensity(ldf, x) end """ @@ -53,7 +45,6 @@ log-density. constructor. ## Example - ```jldoctest julia> using DynamicPPL, Distributions, MarginalLogDensities @@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0) marginal log-density can be performed in unconstrained space. However, care must be taken if the model contains variables where the link transformation depends on a marginalized variable. For example: - ```julia - @model function f() - x ~ Normal() - y ~ truncated(Normal(); lower=x) - end + @model function f() + x ~ Normal() + y ~ truncated(Normal(); lower=x) + end ``` Here, the support of `y`, and hence the link transformation used, depends on the value @@ -101,7 +91,7 @@ function DynamicPPL.marginalize( method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), kwargs..., ) - # Construct the marginal log-density model. + # Construct the log-density function directly from the model and varinfo. ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn @@ -110,121 +100,49 @@ function DynamicPPL.marginalize( (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(ldf, varinfo), - varinfo[:], - varindices, - (), - method; - kwargs..., + ldf, varinfo[:], varindices, (), method; kwargs... ) return mld end """ - VarInfo( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + InitFromVector( + mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, unmarginalized_params::Union{AbstractVector,Nothing}=nothing ) -Retrieve the `VarInfo` object used in the marginalisation process. - -If a Laplace approximation was used for the marginalisation, the values of the marginalized -parameters are also set to their mode (note that this only happens if the `mld` object has -been used to compute the marginal log-density at least once, so that the mode has been -computed). +Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose +parameter vector reflects the state of `mld`. -If a vector of `unmarginalized_params` is specified, the values for the corresponding -parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by -performing an optimization of the marginal log-density. +If a Laplace approximation was used for marginalisation, the marginalized parameters are set +to their modal values (note that this requires `mld` to have been evaluated at least once, +so that the mode has been found). -All other aspects of the VarInfo, such as link status, are preserved from the original -VarInfo used in the marginalisation. - -!!! note - - The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be - updated. If you wish to obtain updated log-probabilities, you should re-evaluate the - model with the values inside the returned VarInfo, for example using: - - ```julia - init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing) - oavi = DynamicPPL.OnlyAccsVarInfo(( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.RawValueAccumulator(false), - # ... whatever else you need - )) - _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) - ``` - - You can then extract all the updated data from `oavi`. - -## Example - -```jldoctest -julia> using DynamicPPL, Distributions, MarginalLogDensities - -julia> @model function demo() - x ~ Normal() - y ~ Beta(2, 2) - end -demo (generic function with 2 methods) +If `unmarginalized_params` is provided, those values are used for the non-marginalized +parameters. This vector may be obtained e.g. by optimizing the marginal log-density. -julia> # Note that by default `marginalize` uses a linked VarInfo. - mld = marginalize(demo(), [@varname(x)]); - -julia> using MarginalLogDensities: Optimization, OptimizationOptimJL - -julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. - y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) -OptimizationProblem. In-place: true -u0: 1-element Vector{Float64}: - 2.0 - -julia> # This tells us the optimal (linked) value of `y` is around 0. - opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) -retcode: Success -u: 1-element Vector{Float64}: - 4.88281250001733e-5 - -julia> # Get the VarInfo corresponding to the mode of `y`. - vi = VarInfo(mld, opt_solution.u); - -julia> # `x` is set to its mode (which for `Normal()` is zero). - vi[@varname(x)] -0.0 - -julia> # `y` is set to the optimal value we found above. - DynamicPPL.getindex_internal(vi, @varname(y)) -1-element Vector{Float64}: - 4.88281250001733e-5 - -julia> # To obtain values in the original constrained space, we can either - # use `getindex`: - vi[@varname(y)] -0.5000122070312476 - -julia> # Or invlink the entire VarInfo object using the model: - vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] -2-element Vector{Float64}: - 0.0 - 0.5000122070312476 +To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link +status — use the returned strategy to re-evaluate the model: +```julia +init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u) +ldf = mld.logdensity +_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy) ``` """ -function DynamicPPL.VarInfo( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, +function DynamicPPL.InitFromVector( + mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) - # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.varinfo - # Extract the stored parameters, which includes the modes for any marginalized - # parameters + # Retrieve the full cached parameter vector (includes modal values for marginalized + # parameters if a Laplace approximation has been run). full_params = MarginalLogDensities.cached_params(mld) - # We can then (if needed) set the values for any non-marginalized parameters + # Overwrite the non-marginalized entries if the caller supplied them. if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end - return DynamicPPL.unflatten!!(original_vi, full_params) + # Use the convenience constructor that reads varname_ranges and transform_strategy + # directly from the LogDensityFunction stored inside mld. + return DynamicPPL.InitFromVector(full_params, mld.logdensity) end end diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 32c4bb479..3aaf81f5a 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -79,10 +79,17 @@ using ADTypes: AutoForwardDiff @testset "unlinked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) - mx([0.5]) # evaluate at some point to force calculation of Laplace approx - vi = VarInfo(mx) + mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values + strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values + ldf = mx.logdensity + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) - vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked + strategy = DynamicPPL.InitFromVector(mx, [0.5]) # same, but override the unmarginalized parameter with 0.5 + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) @test vi[@varname(y)] ≈ 0.5 end @@ -90,9 +97,16 @@ using ADTypes: AutoForwardDiff @testset "linked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) mx([0.5]) # evaluate at some point to force calculation of Laplace approx - vi = VarInfo(mx) + strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values + ldf = mx.logdensity + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) - vi = VarInfo(mx, [0.5]) # this 0.5 is linked + strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) @test vi[@varname(x)] ≈ mode(Normal()) # when using getindex it always returns unlinked values From effb4508f7efb6f330fa21f6780fc10c9769d6f7 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 24 Mar 2026 19:15:24 +0530 Subject: [PATCH 2/3] Fix formatting --- HISTORY.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 9c218947f..2d412fd25 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,8 +1,8 @@ # 0.41 -Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` -from the MarginalLogDensities extension. Users should now use -`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy +Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` +from the MarginalLogDensities extension. Users should now use +`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy and pass it to `init!!` to get a consistent `VarInfo`. # 0.40.14 From 05347ca7fe0f7cd9b64b77d5851fe4a8671edad9 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 24 Mar 2026 19:24:33 +0530 Subject: [PATCH 3/3] Remove stray merge conflict marker from HISTORY.md --- HISTORY.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index eee83831a..2436899b1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,11 +1,10 @@ - # 0.41 Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` from the MarginalLogDensities extension. Users should now use `DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy and pass it to `init!!` to get a consistent `VarInfo`. -======= + # 0.40.15 DynamicPPL now allows you to set the type that log-probabilities are initialised with, using the `DynamicPPL.set_logprob_type!` function.