diff --git a/HISTORY.md b/HISTORY.md index 5b4c37b30..2436899b1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,10 @@ +# 0.41 + +Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` +from the MarginalLogDensities extension. Users should now use +`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy +and pass it to `init!!` to get a consistent `VarInfo`. + # 0.40.15 DynamicPPL now allows you to set the type that log-probabilities are initialised with, using the `DynamicPPL.set_logprob_type!` function. diff --git a/docs/make.jl b/docs/make.jl index 389cd7921..98258af75 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC using MarginalLogDensities: MarginalLogDensities using Random -# Need this to document a method which uses a type inside the extension... -DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt) - # Doctest setup DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true diff --git a/docs/src/api.md b/docs/src/api.md index f6e66b011..88d2e7393 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -170,10 +170,10 @@ marginalize ``` A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability. 
-To retrieve a VarInfo object from it, you can use: +To obtain an initialisation strategy reflecting the state of the marginalisation, you can use: ```@docs -VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing}) +InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing}) ``` ## Models within models diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 348152d30..b2f233c6a 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities -# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by -# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type -# below. -struct LogDensityFunctionWrapper{ - L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo -} - logdensity::L - # This field is used only to reconstruct the VarInfo later on; it's not needed for the - # actual log-density evaluation. - varinfo::V -end -function (lw::LogDensityFunctionWrapper)(x, _) - return LogDensityProblems.logdensity(lw.logdensity, x) +# Make LogDensityFunction directly callable with the two-argument interface expected by +# MarginalLogDensities, i.e. `f(x, data)`. The second argument is the fixed `data`, which +# is unused here because `marginalize` always passes `()` for it. +function (ldf::DynamicPPL.LogDensityFunction)(x, _) + return LogDensityProblems.logdensity(ldf, x) end """ @@ -53,7 +45,6 @@ log-density. constructor. ## Example - ```jldoctest julia> using DynamicPPL, Distributions, MarginalLogDensities @@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0) marginal log-density can be performed in unconstrained space. 
However, care must be taken if the model contains variables where the link transformation depends on a marginalized variable. For example: - ```julia - @model function f() - x ~ Normal() - y ~ truncated(Normal(); lower=x) - end + @model function f() + x ~ Normal() + y ~ truncated(Normal(); lower=x) + end ``` Here, the support of `y`, and hence the link transformation used, depends on the value @@ -101,7 +91,7 @@ function DynamicPPL.marginalize( method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), kwargs..., ) - # Construct the marginal log-density model. + # Construct the log-density function directly from the model and varinfo. ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn @@ -110,121 +100,49 @@ function DynamicPPL.marginalize( (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(ldf, varinfo), - varinfo[:], - varindices, - (), - method; - kwargs..., + ldf, varinfo[:], varindices, (), method; kwargs... ) return mld end """ - VarInfo( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + InitFromVector( + mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, unmarginalized_params::Union{AbstractVector,Nothing}=nothing ) -Retrieve the `VarInfo` object used in the marginalisation process. - -If a Laplace approximation was used for the marginalisation, the values of the marginalized -parameters are also set to their mode (note that this only happens if the `mld` object has -been used to compute the marginal log-density at least once, so that the mode has been -computed). +Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose +parameter vector reflects the state of `mld`. 
-If a vector of `unmarginalized_params` is specified, the values for the corresponding -parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by -performing an optimization of the marginal log-density. +If a Laplace approximation was used for marginalisation, the marginalized parameters are set +to their modal values (note that this requires `mld` to have been evaluated at least once, +so that the mode has been found). -All other aspects of the VarInfo, such as link status, are preserved from the original -VarInfo used in the marginalisation. - -!!! note - - The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be - updated. If you wish to obtain updated log-probabilities, you should re-evaluate the - model with the values inside the returned VarInfo, for example using: - - ```julia - init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing) - oavi = DynamicPPL.OnlyAccsVarInfo(( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.RawValueAccumulator(false), - # ... whatever else you need - )) - _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) - ``` - - You can then extract all the updated data from `oavi`. - -## Example - -```jldoctest -julia> using DynamicPPL, Distributions, MarginalLogDensities - -julia> @model function demo() - x ~ Normal() - y ~ Beta(2, 2) - end -demo (generic function with 2 methods) +If `unmarginalized_params` is provided, those values are used for the non-marginalized +parameters. This vector may be obtained e.g. by optimizing the marginal log-density. -julia> # Note that by default `marginalize` uses a linked VarInfo. - mld = marginalize(demo(), [@varname(x)]); - -julia> using MarginalLogDensities: Optimization, OptimizationOptimJL - -julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. 
- y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) -OptimizationProblem. In-place: true -u0: 1-element Vector{Float64}: - 2.0 - -julia> # This tells us the optimal (linked) value of `y` is around 0. - opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) -retcode: Success -u: 1-element Vector{Float64}: - 4.88281250001733e-5 - -julia> # Get the VarInfo corresponding to the mode of `y`. - vi = VarInfo(mld, opt_solution.u); - -julia> # `x` is set to its mode (which for `Normal()` is zero). - vi[@varname(x)] -0.0 - -julia> # `y` is set to the optimal value we found above. - DynamicPPL.getindex_internal(vi, @varname(y)) -1-element Vector{Float64}: - 4.88281250001733e-5 - -julia> # To obtain values in the original constrained space, we can either - # use `getindex`: - vi[@varname(y)] -0.5000122070312476 - -julia> # Or invlink the entire VarInfo object using the model: - vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] -2-element Vector{Float64}: - 0.0 - 0.5000122070312476 +To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link +status — use the returned strategy to re-evaluate the model: +```julia +init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u) +ldf = mld.logdensity +_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy) ``` """ -function DynamicPPL.VarInfo( - mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, +function DynamicPPL.InitFromVector( + mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) - # Extract the original VarInfo. Its contents will in general be junk. 
- original_vi = mld.logdensity.varinfo - # Extract the stored parameters, which includes the modes for any marginalized - # parameters + # Retrieve the full cached parameter vector (includes modal values for marginalized + # parameters if a Laplace approximation has been run). full_params = MarginalLogDensities.cached_params(mld) - # We can then (if needed) set the values for any non-marginalized parameters + # Overwrite the non-marginalized entries if the caller supplied them. if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end - return DynamicPPL.unflatten!!(original_vi, full_params) + # Use the convenience constructor that reads varname_ranges and transform_strategy + # directly from the LogDensityFunction stored inside mld. + return DynamicPPL.InitFromVector(full_params, mld.logdensity) end end diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 32c4bb479..3aaf81f5a 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -79,10 +79,17 @@ using ADTypes: AutoForwardDiff @testset "unlinked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) - mx([0.5]) # evaluate at some point to force calculation of Laplace approx - vi = VarInfo(mx) + mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values + strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values + ldf = mx.logdensity + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) - vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked + strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is an unlinked value for the unmarginalized parameter y + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) @test 
vi[@varname(y)] ≈ 0.5 end @@ -90,9 +97,16 @@ using ADTypes: AutoForwardDiff @testset "linked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) mx([0.5]) # evaluate at some point to force calculation of Laplace approx - vi = VarInfo(mx) + strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values + ldf = mx.logdensity + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) @test vi[@varname(x)] ≈ mode(Normal()) - vi = VarInfo(mx, [0.5]) # this 0.5 is linked + strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y + _, vi = DynamicPPL.init!!( + ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy + ) binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) @test vi[@varname(x)] ≈ mode(Normal()) # when using getindex it always returns unlinked values