Remove LogDensityFunctionWrapper and replace VarInfo(mld, ...) with InitFromVector #1335

Open
anurag-mds wants to merge 4 commits into TuringLang:main from anurag-mds:remove-logdensityfunctionwrapper-v2

Conversation

@anurag-mds
Contributor

LogDensityFunctionWrapper was a trivial struct whose sole purpose was to convert the LogDensityFunction type to the two-argument callable signature expected by MarginalLogDensities, and to carry a VarInfo to be reconstructed later. But LogDensityFunction already contains all the information we'll ever need (it contains both the model and the transform_strategy as well as _varname_ranges), so this wrapper was unnecessary. I removed it and added a two-argument callable method to LogDensityFunction in the extension.

The VarInfo(mld, ...) method, which this wrapper enabled, was also problematic: it unflattened parameter values into a VarInfo using unflatten!!, which left the log probabilities and other accumulated data in an inconsistent state. I've now replaced it with InitFromVector(mld, ...), which returns an initialisation strategy that the user can pass to init!! to get a fully consistent VarInfo by re-evaluating the model.

Closes #1308

@codecov

codecov bot commented Mar 23, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.92%. Comparing base (a1e8f06) to head (05347ca).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1335      +/-   ##
==========================================
- Coverage   77.93%   77.92%   -0.02%     
==========================================
  Files          50       50              
  Lines        3585     3583       -2     
==========================================
- Hits         2794     2792       -2     
  Misses        791      791              


@anurag-mds anurag-mds force-pushed the remove-logdensityfunctionwrapper-v2 branch from b2d8a8a to fcca4c7 Compare March 23, 2026 18:29
@yebai
Member

yebai commented Mar 23, 2026

I’d suggest we keep VarInfo.

@penelopeysm
Member

I wholly disagree -- the varinfo generated contained wrong data, and it's not something we should be giving people to use.

@yebai
Member

yebai commented Mar 23, 2026

Ah, okay, maybe I didn't fully get the motivation. Let's chat when we meet.

@anurag-mds
Contributor Author

Actually, I first ran JuliaFormatter directly on the whole repository:

using JuliaFormatter
format(".")

This changed multiple files and also threw many warnings saying that the formatter couldn't be applied.

Then I ran it on just the files I changed:

using JuliaFormatter
format(["HISTORY.md", "docs/make.jl", "docs/src/api.md", "ext/DynamicPPLMarginalLogDensitiesExt.jl", "test/ext/DynamicPPLMarginalLogDensitiesExt.jl"])

The formatter worked and I pushed the code, but the formatting CI is still failing and I can't work out why.
Can you help me?

@penelopeysm
Member

Are you using JuliaFormatter v1? https://turinglang.org/docs/contributing/code-formatting/


```@docs
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing})
```
Member

I am concerned that InitFromVector is not a super clear name and that we may need to rename it later, whereas VarInfor is more stable.

Member

@penelopeysm penelopeysm Mar 26, 2026

It already exists, though: https://turinglang.org/DynamicPPL.jl/stable/ldf/models/ and it's not really that complicated to use.

I'm conscious of backwards compatibility, but I really strongly believe that VarInfo is not the right thing to keep around. As you are aware, there have been many perf gains in DynamicPPL recently. But we didn't get them because I'm some Linus Torvalds-style genius hacker who optimised low-level code: we got them because we have better, more modular interfaces that don't make users pay for things they don't need.

VarInfo is the exact antithesis of that: it's a struct that bundles lots of stuff together and forces users to pay for things they don't need. Besides the performance issues, it also leads to incorrect results with unflatten!! because log-probs and transforms are not updated in tandem with the parameters, which is an issue I've pointed out many times. That's exactly the reason why I think we should remove this method.

If we want Turing/DynamicPPL to be something that people trust, then correctness is very, very important. I think it's the single most important thing, even more so than performance.

@anurag-mds anurag-mds force-pushed the remove-logdensityfunctionwrapper-v2 branch from 2bfc892 to effb450 Compare March 24, 2026 13:45
@anurag-mds
Contributor Author

> Are you using JuliaFormatter v1? https://turinglang.org/docs/contributing/code-formatting/

Fixed: I was using the wrong JuliaFormatter version. The formatting CI passes now.

@anurag-mds
Contributor Author

anurag-mds commented Mar 25, 2026

I haven't commented yet as I was following the discussion between the maintainers regarding the interface design.

I just wanted to clarify: was the change mentioned by @yebai intended for me to implement, or was it a request for @penelopeysm’s approval?

If it is for me, I will update the name from InitFromVector to VarInfor. Is there any other change I should include?

Could you also clarify if VarInfor was a typo? I assume you meant the standard VarInfo.

@penelopeysm
Member

It's a typo. And it's not just about the name, it's really about the interface design. I don't really want to repeat myself, so I'll just point to #1308 (comment).

Comment on lines +6 to 11
# Make LogDensityFunction directly callable with the two-argument interface expected by
# MarginalLogDensities. The second argument is the extra data that MarginalLogDensities
# passes to the callable; it is unused here.
function (ldf::DynamicPPL.LogDensityFunction)(x, _)
    return LogDensityProblems.logdensity(ldf, x)
end
Member

I think it's better to not define this function. This makes all LogDensityFunctions callable structs, and it can be very confusing to debug behaviour because if you load this extension it will be callable, and if you don't it won't.

(),
method;
kwargs...,
ldf, varinfo[:], varindices, (), method; kwargs...
Member

I recognise that the MarginalLogDensity interface requires you to pass a function. Instead of making LogDensityFunction callable, you could instead pass Base.Fix1(LogDensityProblems.logdensity, ldf) as the argument here. That is also a callable, but it accomplishes the same thing without making LogDensityFunction itself callable.

Comment on lines +124 to +129
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
status — use the returned strategy to re-evaluate the model:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
```
Member

Suggested change
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
status — use the returned strategy to re-evaluate the model:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
```
The returned `InitFromVector` strategy can then be used to re-evaluate the model (see also
[the DynamicPPL docs](@ref ldf-model)). For example, if `opt_solution` is a vector of
unmarginalised parameters obtained from optimisation of the `mld` object, then you can
write:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
accs = DynamicPPL.OnlyAccsVarInfo()
_, accs = DynamicPPL.init!!(ldf.model, accs, init_strategy, ldf.transform_strategy)
```

Comment on lines +85 to 88
_, vi = DynamicPPL.init!!(
    ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] ≈ mode(Normal())
Member

Suggested change
_, vi = DynamicPPL.init!!(
    ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
)
@test vi[@varname(x)] ≈ mode(Normal())
accs = OnlyAccsVarInfo(RawValueAccumulator(false))
_, accs = DynamicPPL.init!!(
    ldf.model, accs, strategy, ldf.transform_strategy
)
@test get_raw_values(accs)[@varname(x)] ≈ mode(Normal())

And likewise for the other tests too (see e.g. https://turinglang.org/DynamicPPL.jl/stable/migration/#Getting-parameter-values).

I think it might also be better to just extract this into a separate function (the function could return, e.g., get_raw_values(accs) and then in the tests you can check the values in it). There's quite a lot of code duplication in these tests.

Comment on lines +3 to +7
Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)`
from the MarginalLogDensities extension. Users should now use
`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy
and pass it to `init!!` to get a consistent `VarInfo`.

Member

  1. Please could you rebase against the breaking branch, since this is an API change?
  2. The changelog doesn't need to be formatted to a certain character width; we use semantic linebreaks in the changelog (i.e. each sentence is on its own line). You can see the rest of the changelog for examples.

@anurag-mds
Contributor Author

Hi @penelopeysm, I ran into two issues:

Issue 1: Base.Fix1 doesn't work
When using Base.Fix1(LogDensityProblems.logdensity, ldf), I get:

MethodError: no method matching logdensity(::LogDensityFunction, ::Vector{...}, ::Tuple{})

This happens because MarginalLogDensities calls the callable with two arguments (x, extra_data), so Base.Fix1 expands to logdensity(ldf, x, extra_data), which has 3 arguments instead of 2.

My suggestion is to use an anonymous function instead:

(x, _) -> LogDensityProblems.logdensity(ldf, x)

This captures ldf without making LogDensityFunction globally callable, and correctly absorbs the second argument. Could you confirm if this is the right approach?
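
For readers following along, here is a minimal sketch of the difference between the two approaches, using only Base Julia. `toy_logdensity`, `scale`, and `data` are hypothetical stand-ins (for `LogDensityProblems.logdensity`, `ldf`, and the extra data argument respectively) and are not part of DynamicPPL or MarginalLogDensities. The exact error message depends on the Julia version, but on the versions I'm aware of the two-argument call on a `Base.Fix1` ends in a `MethodError` either way.

```julia
# Hypothetical stand-in for LogDensityProblems.logdensity(ldf, x):
# it accepts exactly two positional arguments.
toy_logdensity(scale::Real, x::AbstractVector) = -scale * sum(abs2, x) / 2

scale = 2.0        # plays the role of `ldf`
x = [1.0, 2.0]     # parameter vector
data = ()          # plays the role of the extra argument passed by the caller

# Partially applied function: fine when called with one argument.
fixed = Base.Fix1(toy_logdensity, scale)
fixed(x)

# Called with two arguments, the extra one is not absorbed, so the call fails.
err = try
    fixed(x, data)
catch e
    e
end
@assert err isa MethodError

# A two-argument closure captures `scale` and ignores the second argument.
closure = (x, _) -> toy_logdensity(scale, x)
@assert closure(x, data) == fixed(x)
```

The closure keeps the two-argument adapter local to the call site, so nothing about the wrapped type itself becomes callable.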

Issue 2: Tracer types -- <:Real constraint
MarginalLogDensities passes tracer/dual number types during Hessian computation, but logdensity only accepts AbstractVector{<:Real} and tracer types are not <:Real. I removed the <:Real constraint in 3 places in src/logdensityfunction.jl (logdensity, logdensity_at, and LogDensityAt's callable), and confirmed from the error logs that the closest candidate correctly changed from AbstractVector{<:Real} to AbstractVector.

Could you confirm whether this is the right fix?

Development

Successfully merging this pull request may close these issues.

Replace VarInfo argument in MLDExt with transform status

3 participants