From 8c3d30f78b375e00c1ed4fba75b0d8d0414e43a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Oct 2025 18:08:25 +0100 Subject: [PATCH 001/148] v0.39 --- HISTORY.md | 2 ++ Project.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 57ccaecd1..8165317f6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,7 @@ # DynamicPPL Changelog +## 0.39.0 + ## 0.38.0 ### Breaking changes diff --git a/Project.toml b/Project.toml index 2fe65fd7b..7f58083bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.0" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7300c224dc0fbac2febea2b6185c74f48181dcc9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:48:58 +0000 Subject: [PATCH 002/148] Update DPPL compats for benchmarks and docs --- benchmarks/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 0d4e9a654..55ca81da0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -23,7 +23,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" diff --git a/docs/Project.toml b/docs/Project.toml index fed06ebde..69e0a4c5a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" From 79150baf0a70380f46d9573746a900f7c38bb370 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Nov 2025 17:31:18 +0000 Subject: [PATCH 003/148] remove merge conflict markers --- HISTORY.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cd4cbc767..45be1772d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -19,8 +19,6 @@ The generic method `returned(::Model, values, keys)` is deprecated and will be r Added a compatibility entry for JET@0.11. -> > > > > > > main - ## 0.38.1 Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`. From 4ca95281ebbe451f99813f009bfc9c7140330a45 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 16:24:14 +0000 Subject: [PATCH 004/148] Remove `NodeTrait` (#1133) * Remove NodeTrait * Changelog * Fix exports * docs * fix a bug * Fix doctests * Fix test * tweak changelog --- HISTORY.md | 17 +++++++ docs/src/api.md | 46 +++++++++++++---- src/DynamicPPL.jl | 13 +++-- src/contexts.jl | 82 ++++++++++++------------------ src/contexts/conditionfix.jl | 92 +++++++++++----------------------- src/contexts/default.jl | 1 - src/contexts/init.jl | 1 - src/contexts/prefix.jl | 17 ++----- src/contexts/transformation.jl | 1 - src/model.jl | 4 +- src/test_utils/contexts.jl | 19 ++----- test/contexts.jl | 36 +++++++------ 12 files changed, 154 insertions(+), 175 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 613957c33..f181897f7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,23 @@ ## 0.39.0 +### Breaking changes + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. 
+This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`. + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users. + +Leaf contexts require no changes, apart from a removal of the `NodeTrait` function. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### Miscellaneous + Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. ## 0.38.9 diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..63dafdfca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -352,13 +352,6 @@ Base.empty! SimpleVarInfo ``` -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -463,15 +456,48 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context. + +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks. +They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. 
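For illustration of the parent-context interface described above, a minimal parent context might look as follows. This is only a sketch (`MyWrapperContext` is a hypothetical name); real parent contexts would typically also modify model evaluation in some way.

```julia
using DynamicPPL: DynamicPPL, AbstractContext, AbstractParentContext, DefaultContext

# A hypothetical parent context that simply wraps a child context.
struct MyWrapperContext{C<:AbstractContext} <: AbstractParentContext
    context::C
end
MyWrapperContext() = MyWrapperContext(DefaultContext())

# The two methods required by the AbstractParentContext interface:
DynamicPPL.childcontext(ctx::MyWrapperContext) = ctx.context
DynamicPPL.setchildcontext(::MyWrapperContext, child::AbstractContext) = MyWrapperContext(child)
```

Because `tilde_assume!!` and `tilde_observe!!` default to delegating to the child context, no further methods are needed for a pure pass-through wrapper.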
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..c43bd89d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -94,16 +94,21 @@ export AbstractVarInfo, values_as_in_model, # LogDensityFunction LogDensityFunction, - # Contexts + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. """ -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. 
```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! not implemented for context of type $(typeof(context))") +end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. 
""" struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ function getconditioned_nested(::IsParent, context, vn) end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. 
""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants. 
""" function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. 
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..83507353f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -150,7 +150,6 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext vn_prefix::Tvn context::C end @@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} end PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) -NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context function setchildcontext(ctx::PrefixContext, child::AbstractContext) return PrefixContext(ctx.vn_prefix, child) @@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`. 
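For example (an illustrative sketch; `PrefixContext` is internal), prefixes are applied innermost-first, so the outermost prefix ends up at the front of the resulting name:

```jldoctest
julia> using DynamicPPL: PrefixContext, DefaultContext, prefix, @varname

julia> ctx = PrefixContext(@varname(a), PrefixContext(@varname(b), DefaultContext()));

julia> prefix(ctx, @varname(x))
a.b.x
```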
function prefix(ctx::PrefixContext, vn::VarName) return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) +prefix(::AbstractContext, vn::VarName) = vn +function prefix(ctx::AbstractParentContext, vn::VarName) return prefix(childcontext(ctx), vn) end @@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) ) return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) +prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName) vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) return vn, setchildcontext(ctx, new_ctx) end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c2eee2863 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -10,7 +10,6 @@ Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the how to do the transformation, used by e.g. `SimpleVarInfo`. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/model.jl b/src/model.jl index ec98b90cd..94fcd9fd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. 
""" function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..ae7332a43 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. 
- function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) From 535ce4f68e8f162fb382fb5d55eae0238d332e7a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 13 Nov 2025 13:30:43 +0000 Subject: [PATCH 005/148] FastLDF / InitContext unified (#1132) * Fast Log Density Function * Make it work with AD * Optimise performance for identity VarNames * Mark `get_range_and_linked` as having zero derivative * Update comment * make AD testing / benchmarking use FastLDF * Fix tests * Optimise away `make_evaluate_args_and_kwargs` * const func annotation * Disable benchmarks on non-typed-Metadata-VarInfo * Fix `_evaluate!!` correctly to handle submodels * Actually fix submodel evaluate * Document thoroughly and organise code * Support more VarInfos, make it thread-safe (?) * fix bug in parsing ranges from metadata/VNV * Fix get_param_eltype for TSVI * Disable Enzyme benchmark * Don't override _evaluate!!, that breaks ForwardDiff (sometimes) * Move FastLDF to experimental for now * Fix imports, add tests, etc * More test fixes * Fix imports / tests * Remove AbstractFastEvalContext * Changelog and patch bump * Add correctness tests, fix imports * Concretise parameter vector in tests * Add zero-allocation tests * Add Chairmarks as test dep * Disable allocations tests on multi-threaded * Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation * Refactor FastLDF to use InitContext * note init breaking change * fix logjac sign * workaround Mooncake segfault * fix changelog too * Fix get_param_eltype for context stacks * Add a test for threaded observe * Export init * Remove dead code * fix transforms for pathological distributions * Tidy up loads of things * fix typed_identity spelling * fix definition order * Improve docstrings * Remove stray comment * export get_param_eltype (unfortunatley) * Add more comment * Update comment * Remove inlines, fix OAVI docstring * Improve docstrings * Simplify InitFromParams constructor * Replace map(identity, x[:]) with [i for i in x[:]] * Simplify implementation for InitContext/OAVI * Add another model to allocation tests Co-authored-by: Markus Hauru * Revert removal of dist argument (oops) * Format * Update some outdated bits of FastLDF docstring * remove underscores --------- Co-authored-by: Markus Hauru --- HISTORY.md | 15 ++ docs/src/api.md | 12 +- ext/DynamicPPLEnzymeCoreExt.jl | 13 +- ext/DynamicPPLMooncakeExt.jl | 3 + src/DynamicPPL.jl | 5 +- src/compiler.jl | 38 ++-- src/contexts/init.jl | 255 ++++++++++++++++++++---- src/experimental.jl | 2 + src/fasteval.jl | 336 ++++++++++++++++++++++++++++++++ src/model.jl | 32 ++- src/onlyaccs.jl | 42 ++++ src/utils.jl | 35 ++++ test/Project.toml | 1 + test/fasteval.jl | 233 ++++++++++++++++++++++ test/integration/enzyme/main.jl | 6 +- test/runtests.jl | 1 + 16 files changed, 955 insertions(+), 74 deletions(-) create mode 100644 src/fasteval.jl create mode 100644 src/onlyaccs.jl create mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 
f181897f7..0f0102ce4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,21 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). + +### Other changes + +#### FastLDF + +Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +Please note that `FastLDF` is currently considered internal and its API may change without warning. +We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. + +For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index 63dafdfca..e81f18dc7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -170,6 +170,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -517,10 +523,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. +In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL._get_range_and_linked), args... 
+) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -5,5 +5,8 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL._get_range_and_linked),Vararg +} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c43bd89d5..e9b902363 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -84,8 +84,8 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, @@ -113,6 +113,8 @@ export AbstractVarInfo, InitFromPrior, InitFromUniform, InitFromParams, + init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, @@ -193,6 +195,7 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
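For instance (a sketch of the corresponding `Type`-valued methods): concrete float types inside the requested type are promoted to match the parameter element type:

```julia
DynamicPPL.get_matching_type(Float32, Vector{Float64})  # == Vector{Float32}
```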
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 83507353f..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -14,14 +14,60 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. 
The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion +""" +get_param_eltype(::AbstractInitStrategy) = Any + """ InitFromPrior() @@ -29,7 +75,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return rand(rng, dist), typed_identity end """ @@ -69,43 +115,61 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return x, typed_identity end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple. +Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +## Custom parameter storage types -!!! 
note - The values in `params` must be provided in the space of the untransformed - distribution. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement + +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return InitFromParams(to_varname_dict(params), fallback) - end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) +InitFromParams(params) = InitFromParams(params, InitFromPrior()) + +function init( + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -119,13 +183,89 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. 
+ +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = if range_and_linked.is_linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return (@view vr.vect[range_and_linked.range]), transform +end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( @@ -155,9 +295,8 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we @@ -165,17 +304,49 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). 
It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. + y, -inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/experimental.jl b/src/experimental.jl index 8c82dca68..c644c09b2 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,6 +2,8 @@ module Experimental using DynamicPPL: DynamicPPL +include("fasteval.jl") + # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) 
diff --git a/src/fasteval.jl b/src/fasteval.jl new file mode 100644 index 000000000..c668b1413 --- /dev/null +++ b/src/fasteval.jl @@ -0,0 +1,336 @@ +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI +using Random: Random + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. 
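+
+## Example
+
+A rough sketch of intended usage (illustrative only, not a doctest; `demo` is a made-up
+model and the numbers are arbitrary):
+
+```julia
+using DynamicPPL, Distributions, ADTypes, LogDensityProblems
+import ForwardDiff
+
+@model function demo()
+    x ~ Normal()
+end
+
+fldf = DynamicPPL.Experimental.FastLDF(demo(); adtype=AutoForwardDiff())
+LogDensityProblems.logdensity(fldf, [0.5])               # ≈ logpdf(Normal(), 0.5)
+LogDensityProblems.logdensity_and_gradient(fldf, [0.5])  # log density and its gradient
+```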
+
+# Extended help
+
+Up until DynamicPPL v0.38, there were two ways of evaluating a DynamicPPL model at a
+given set of parameters:
+
+1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters
+   inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation.
+
+2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores
+   them inside a VarInfo's metadata.
+
+In general, both of these approaches work fine, but the fact that they modify the VarInfo's
+metadata often makes them quite wasteful. In particular, it is very common that the only
+outputs we care about from model evaluation are those which are stored in accumulators, such
+as log probability densities, or `ValuesAsInModel`.
+
+To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
+accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
+model evaluation.
+
+Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
+it, it is mandatory that parameters are provided from outside the VarInfo, namely via
+`InitContext`.
+
+The main problem that we face is that it is not possible to directly implement
+`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`.
+In particular, it is not clear:
+
+  - which parts of the vector correspond to which random variables, and
+  - whether the variables are linked or unlinked.
+
+Traditionally, this problem has been solved by `unflatten`, because that function would
+place values into the VarInfo's metadata alongside the information about ranges and linking.
+That way, when we evaluate with `DefaultContext`, we can read this information out again.
+However, we want to avoid using the metadata at all. Thus, here, we _extract this
+information from the VarInfo_ a single time when constructing a `FastLDF` object. Inside the
+FastLDF, we store a mapping from VarNames to ranges in the parameter vector, along with link
+status.
+
+For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
+other VarNames, this is stored in a Dict. The internal data structure used to represent this
+could almost certainly be optimised further. See e.g. the discussion in
+https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
+
+When evaluating the model, this allows us to combine the parameter vector with those ranges
+to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read parameter
+values from the vector.
+
+Note that this assumes that the ranges and link status are static throughout the lifetime of
+the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable
+numbers of parameters, or models which may visit random variables in different orders
+depending on stochastic control flow. **Indeed, silent errors may occur with such models.**
+This is a general limitation of vectorised parameters: the original `unflatten` +
+`evaluate!!` approach also fails with such models.
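+
+For example, a model along these lines (a hypothetical sketch) is unsafe to use with
+`FastLDF`, because which variables exist depends on the sampled value of `x`:
+
+```julia
+@model function stochastic_support()
+    x ~ Bernoulli(0.5)
+    if x
+        y ~ Normal()       # `y` only exists in some executions...
+    else
+        z ~ Exponential()  # ...so ranges fixed at construction time can be wrong
+    end
+end
+```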
+""" +struct FastLDF{ + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} + model::M + adtype::AD + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + + function FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) + end + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + ) + end +end + +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) +end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) +end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) + +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = InitContext( + Random.default_rng(), + InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + ), + ) + model = DynamicPPL.setleafcontext(f.model, ctx) + accs = fast_ldf_accs(f.getlogdensity) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
+    vi = if Threads.nthreads() > 1
+        accs = map(
+            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
+            accs,
+        )
+        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
+    else
+        OnlyAccsVarInfo(accs)
+    end
+    _, vi = DynamicPPL._evaluate!!(model, vi)
+    return f.getlogdensity(vi)
+end
+
+function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
+    return FastLogDensityAt(
+        fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+    )(
+        params
+    )
+end
+
+function LogDensityProblems.logdensity_and_gradient(
+    fldf::FastLDF, params::AbstractVector{<:Real}
+)
+    return DI.value_and_gradient(
+        FastLogDensityAt(
+            fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+        ),
+        fldf._adprep,
+        fldf.adtype,
+        params,
+    )
+end
+
+######################################################
+# Helper functions to extract ranges and link status #
+######################################################
+
+# This fails for SimpleVarInfo, but honestly there is no reason to support that here:
+# evaluation itself never touches the VarInfo, which is only used once, up front, to
+# generate the ranges and link status. So there is no motivation to use SimpleVarInfo
+# inside a LogDensityFunction any more; we can just always use typed VarInfo. In fact one
+# could argue that there is no purpose in supporting untyped VarInfo either.
+"""
+    get_ranges_and_linked(varinfo::VarInfo)
+
+Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
+representation, along with whether each variable is linked or unlinked.
+
+This function should return a tuple containing:
+
+- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
+- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
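+
+For example, for a hypothetical model with a scalar variable `a` and vector elements
+`x[1]` and `x[2]`, an unlinked VarInfo might produce something like:
+
+```julia
+# NamedTuple for identity-optic VarNames:
+(a = RangeAndLinked(1:1, false),)
+# Dict for all other VarNames:
+Dict(@varname(x[1]) => RangeAndLinked(2:2, false),
+     @varname(x[2]) => RangeAndLinked(3:3, false))
+```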
+""" +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end diff --git a/src/model.jl b/src/model.jl index 94fcd9fd4..2bcfe8f98 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... + ) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1010,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) + +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. 
+
+"""
+function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext)
+    return get_param_eltype(vi, DynamicPPL.childcontext(ctx))
+end
+get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi)
+function get_param_eltype(::AbstractVarInfo, ctx::InitContext)
+    return get_param_eltype(ctx.strategy)
+end
+
 """
     getargnames(model::Model)
 
diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl
new file mode 100644
index 000000000..940f23124
--- /dev/null
+++ b/src/onlyaccs.jl
@@ -0,0 +1,42 @@
+"""
+    OnlyAccsVarInfo
+
+This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo`
+interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for
+`InitContext`.
+
+Note that almost none of the other AbstractVarInfo interface functions are implemented, so
+using this with a different leaf context such as `DefaultContext` will result in errors.
+
+Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field.
+This is also why it only works with `InitContext`: in this case, the parameters used for
+evaluation are supplied by the context instead of the metadata.
+"""
+struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
+    accs::Accs
+end
+OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
+function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N}
+    return OnlyAccsVarInfo(AccumulatorTuple(accs))
+end
+
+# Minimal AbstractVarInfo interface
+DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
+DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
+DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
+
+# Ideally, we'd define this together with InitContext, but alas that file comes way before
+# this one, and sorting out the include order is a pain.
+function tilde_assume!!(
+    ctx::InitContext,
+    dist::Distribution,
+    vn::VarName,
+    vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}},
+)
+    # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can
+    # cut out a lot of the code above.
+    val, transform = init(ctx.rng, vn, dist, ctx.strategy)
+    x, inv_logjac = with_logabsdet_jacobian(transform, val)
+    vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
+    return x, vi
+end
diff --git a/src/utils.jl b/src/utils.jl
index b55a2f715..75fb805dc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems.
 """
 const LogProbType = float(Real)
 
+"""
+    typed_identity(x)
+
+Identity function, but with an overload for `with_logabsdet_jacobian` to ensure
+that it returns a sensible zero logjac.
+
+The problem with plain old `identity` is that the default definition of
+`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`:
+https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154
+
+This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g.
+if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with,
+for example, `ProductNamedTupleDistribution`:
+
+```julia
+julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5)));
+
+julia> eltype(rand(d))
+Any
+```
+
+The same problem would also preclude eventually broadening the scope of DynamicPPL.jl to
+support distributions with non-numeric samples.
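+
+For illustration (not a doctest), continuing the session above: the overload defined below
+returns a concrete zero instead.
+
+```julia
+julia> using Bijectors
+
+julia> Bijectors.with_logabsdet_jacobian(DynamicPPL.typed_identity, rand(d))[2]
+0.0
+```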
+ +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. +""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..db2333711 --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,233 @@ +module DynamicPPLFastLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end + +@testset "FastLDF: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. 
+ @testset "$getlogdensity" for getlogdensity in ( + DynamicPPL.getlogjoint_internal, + DynamicPPL.getlogjoint, + DynamicPPL.getloglikelihood, + DynamicPPL.getlogprior_internal, + DynamicPPL.getlogprior, + ) + ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) + fldf = FastLDF(m, getlogdensity, vi) + @test LogDensityProblems.logdensity(ldf, params) ≈ + LogDensityProblems.logdensity(fldf, params) + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.Experimental.FastLDF(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end + end +end + +@testset "FastLDF: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with FastLDF" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = @static if VERSION < v"1.12" + [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + else + [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] + end + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. 
+ @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = FastLDF(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..10fac8b0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,7 @@ include("test_util.jl") include("ext/DynamicPPLMooncakeExt.jl") end include("ad.jl") + include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." 
From 9624103885a6f8b6cac70b1d0796da5a6227b65a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:34:59 +0000 Subject: [PATCH 006/148] implement `LogDensityProblems.dimension` --- src/fasteval.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index c668b1413..aa2fdd933 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -149,6 +149,7 @@ struct FastLDF{ _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adprep::ADP + _dim::Int function FastLDF( model::Model, @@ -159,13 +160,14 @@ struct FastLDF{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - x = [val for val in varinfo[:]] DI.prepare_gradient( FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), adtype, @@ -179,7 +181,7 @@ struct FastLDF{ typeof(all_iden_ranges), typeof(prep), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end @@ -260,6 +262,10 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.dimension(fldf::FastLDF) + return fldf._dim +end + ###################################################### # Helper functions to extract ranges and link status # ###################################################### From ce807139b31919b40afc98bcc522e2f41e14dc30 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:40:37 +0000 Subject: [PATCH 007/148] forgot about capabilities... 
--- src/fasteval.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fasteval.jl b/src/fasteval.jl index aa2fdd933..4f402f4a8 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -262,6 +262,16 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} +) where {M} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} +) where {M} + return LogDensityProblems.LogDensityOrder{1}() +end function LogDensityProblems.dimension(fldf::FastLDF) return fldf._dim end From 8553e401182b894a653b638e5978f00ae42ae031 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 12:03:14 +0000 Subject: [PATCH 008/148] use interpolation in run_ad --- src/test_utils/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..d7a34e6e0 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -298,8 +298,8 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark - primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) - grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) + primal_benchmark = @be logdensity($ldf, $params) + grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) From 3cd8d3431e14ebc581266c1323d1db8a5bd4c0eb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 17:48:32 +0000 Subject: [PATCH 009/148] Improvements to benchmark outputs (#1146) * print output * fix * reenable * add more lines to guide the eye * reorder table * print tgrad / trel as well * forgot this type --- benchmarks/benchmarks.jl | 38 +++++++++++++++++++++++++++++++++----- src/test_utils/ad.jl | 10 +++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 3af6573cf..e8ffa7e0b 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -98,12 +98,15 @@ function run(; to_json=false) }[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name" + @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try results = benchmark(model, varinfo_choice, adbackend, islinked) + @info " t(eval) = $(results.primal_time)" + @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), (results.grad_time / results.primal_time) catch e + @info "benchmark errored: $e" missing, missing end push!( @@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) ) results_table = Tuple{ - String,Int,String,String,Bool,String,String,String,String,String,String + String, + Int, + String, + String, + Bool, + String, + String, + String, + String, + String, + String, + String, + String, + String, }[] + sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ EmptyCells(5), 
MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), + MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"], + [colnames[1:5]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String) # Finally that lets us do this division safely speedup_eval = base_eval / head_eval speedup_grad = base_grad / head_grad + # As well as this multiplication, which is t(grad) / t(ref) + head_grad_vs_ref = head_grad * head_eval + base_grad_vs_ref = base_grad * base_eval + speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref push!( results_table, ( @@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String) sprint_float(base_grad), sprint_float(head_grad), sprint_float(speedup_grad), + sprint_float(base_grad_vs_ref), + sprint_float(head_grad_vs_ref), + sprint_float(speedup_grad_vs_ref), ), ) end @@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String) backend=:text, fit_table_in_display_horizontally=false, fit_table_in_display_vertically=false, - table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true), + table_format=TextTableFormat(; + horizontal_line_at_merged_column_labels=true, + horizontal_lines_at_data_rows=collect(3:3:length(results_table)), + ), ) println("```") end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d7a34e6e0..8ee850877 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,7 +5,13 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link + DynamicPPL, + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + getlogjoint_internal, + link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -298,7 +304,9 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark + logdensity(ldf, params) # Warm-up primal_benchmark = @be logdensity($ldf, $params) + logdensity_and_gradient(ldf, params) # Warm-up grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time From eab71317d406a56fe06df7c8f944a4063e564112 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 17:08:31 +0000 Subject: [PATCH 010/148] Add VarNamedTuple, tests, and WIP docs --- docs/src/internals/varnamedtuple.md | 112 ++++++++++ src/varnamedtuple.jl | 310 ++++++++++++++++++++++++++++ test/varnamedtuple.jl | 89 ++++++++ 3 files changed, 511 insertions(+) create mode 100644 docs/src/internals/varnamedtuple.md create mode 100644 src/varnamedtuple.jl create mode 100644 test/varnamedtuple.jl diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md new file mode 100644 index 000000000..9f7a84cdb --- /dev/null +++ b/docs/src/internals/varnamedtuple.md @@ -0,0 +1,112 @@ +# VarNamedTuple as the basis of VarInfo + +This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that. 
+
+## The current situation
+
+We currently have the following AbstractVarInfo types:
+
+  - A: VarInfo with Metadata
+  - B: VarInfo with VarNamedVector
+  - C: VarInfo with NamedTuple, with values being Metadata
+  - D: VarInfo with NamedTuple, with values being VarNamedVector
+  - E: SimpleVarInfo with NamedTuples
+  - F: SimpleVarInfo with OrderedDict
+
+A and C are the classic ones, and the defaults. C groups the Metadata objects by the lead
+Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows
+different lead Symbols to have different element types and for the VarInfo to still be type
+stable. B and D were created to simplify A and C, give them a nicer interface, and make them
+deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks they
+are quite a lot slower, which needs work.
+
+E and F are entirely distinct in implementation from the others. E is simply a mapping from
+Symbols to values, with each VarName being converted to a single symbol, e.g.
+`Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as
+the key type.
+
+A-D carry within them values for variables, but also their bijectors/distributions, and
+store all values vectorised, using the bijectors to map to the original values. They also
+store for each variable a flag for whether the variable has been linked. E-F store only the
+raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link
+transform itself is implicit.
+
+TODO: Write a better summary of pros and cons of each approach.
+
+## VarNamedTuple
+
+VarNamedTuple has been discussed as a possible data structure to generalise the structure
+used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The
+same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The
+idea was to encapsulate this structure into its own type, reducing code duplication and
+making the design more robust and powerful. See
+https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion.
+
+An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus
+on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can
+make it work for other purposes (condition, fix, Gibbs) as well.
+
+Without going into full detail, here's @mhauru's current proposal for what it would look
+like. This proposal remains in constant flux as I develop the code.
+
+A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of
+using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for
+random variables. However, they could also carry extra information with them. For instance,
+we might use a value that is a tuple of a vectorised value, a bijector, and a flag for
+whether the variable is linked.
+
+I sometimes shorten VarNamedTuple to VNT.
+
+Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping
+`@varname(x) => 1, @varname(y.z) => 2` would be stored as
+
+```
+(; x=1, y=(; z=2))
+```
+
+(This is a slight simplification; really it would be nested VarNamedTuples rather than
+NamedTuples, but I omit this detail.) This forms a tree, with each node being a NamedTuple,
+like so:
+
+```
+    NT
+ x /  \ y
+  1    NT
+         \ z
+          2
+```
+
+Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has
+the keys `x` and `y`.
+This nesting is like the type stable VarInfo in our current design, except with possibly
+more levels (our current one only has the root node). Each nested `PropertyLens`, i.e. each
+`.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree.
+
+For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a
+`PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that
+VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and
+`identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)`
+and `@varname(a.b.c.d[i,j,k])`.
+
+This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have
+been covered. The leaves of the tree are then of two kinds: they are either the raw value
+itself if the last lens of the VarName is an `identity`, or otherwise they are something
+that can be indexed with an `IndexLens`, such as an `Array`.
+
+Getting a value from a VarNamedTuple is very simple: for `getindex(vnt::VNT, vn::VarName{S})`
+(`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the
+last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've
+reached contains something that can be indexed with it.
+
+Setting values in a VNT is equally simple if there are no `IndexLenses`: for
+`setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree
+corresponding to `vn` and sets its value to `value`.
+
+The tricky part is what to do when setting values with `IndexLenses`. There are three
+possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`.
+
+ 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is
+    easy: just set the third element.
+ 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do?
+    Do we error? Do we extend that vector?
+ 3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that
+    `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`,
+    where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary?
+    A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this
+    `setindex!!` call simply error?
+
+A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may
+or may not operate in place. The NamedTuples can't be modified in place, but the values at
+the leaves may be. Always using a `!!` function makes type stability easier, and makes
+structures like the type unstable old VarInfo with Metadata unnecessary: any value can be
+set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary.
+
+To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a
+bit. This will also help make VNT more flexible, which may help performance or allow more
+use cases. The modification is this:
+
+Contrary to what I said above, let's say that VNT isn't just nested NamedTuples with some
+values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value,
+lens)` is a function that takes any value, and a lens that is either `identity` or an
+`IndexLens`, and returns the value wrapped in some suitable struct that can be stored in
+the leaf of the NamedTuple tree. The values should always be such that
+`make_leaf(value, lens)[lens] == value`.
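+
+As a quick illustration of this contract, here is a hypothetical Dict-based `make_leaf`
+(keyed by the raw index tuple for simplicity, so the lens-indexing in the contract is
+spelled out by hand):
+
+```julia
+using Accessors: IndexLens
+
+my_make_leaf(value, ::typeof(identity)) = value
+my_make_leaf(value, lens::IndexLens) = Dict(lens.indices => value)
+
+leaf = my_make_leaf(2.0, IndexLens((3,)))
+leaf[(3,)] == 2.0  # the stored value round-trips, i.e. make_leaf(value, lens)[lens] == value
+```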
+
+Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)`
+would be stored as a tree like
+
+```
+        --NT--
+     x /      \ y
+f(1, identity) NT
+                \ z
+            f(2, identity)
+```
+
+The first draft of VNT above, which did not include `make_leaf`, is equivalent to the
+trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how
+to deal with IndexLenses")`. The problems 2. and 3. above are "solved" by making it
+`make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return
+a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it
+can initialise a vector type that can grow as needed when indexed into.
+
+The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo,
+find a good default, and, if necessary, leave the option for power users to customise
+behaviour. The first ones to implement would be:
+
+  - `make_leaf` that returns a Metadata object. This would be a direct replacement for type
+    stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple.
+  - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for
+    SimpleVarInfo with OrderedDict.
+
+You may ask: have we simply gone from too many VarInfo types to too many `make_leaf`
+functions? Yes we have. But hopefully we have gained something in the process:
+
+  - The leaf types can be simpler. They do not need to deal with VarNames any more, they
+    only need to deal with `identity` lenses and `IndexLenses`.
+  - All AbstractVarInfos are as type stable as their leaf types allow. There is no more
+    notion of an untyped VarInfo being converted to a typed one.
+  - Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`,
+    which happens a lot with submodels.
+  - Many functions that are currently implemented individually for each AbstractVarInfo
+    type would now have a single implementation for the VarNamedTuple-based AbstractVarInfo
+    type, reducing code duplication. I would also hope to get rid of most of the generated
+    functions in `varinfo.jl`.
+
+My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf`
+function that stores the raw values when the lens is an `identity`, and uses a flexible
+Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong
+on that being the best option. Implementing and benchmarking is the only way to know.
+
+I think the two big questions are:
+
+  - Will we run into some big, unanticipated blockers when we start to implement this?
+  - Will the nesting of NamedTuples cause performance regressions, if the compiler either
+    chokes or gives up?
+
+I'll try to derisk these early on in this PR.
+
+## Questions / issues
+
+  - People might really need IndexLenses in the middle of VarNames. The one place this comes
+    up is submodels within a loop. I'm still inclined to keep designing without allowing for
+    that, for now, but should keep in mind that that needs to be relaxed eventually. If it
+    makes it easier, we can require that users explicitly tell us the size of any arrays for
+    which this is done.
+  - When storing values for nested NamedTuples, the actual variable may be a struct. Do we
+    need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
+  - Do `Colon` indices cause any extra trouble for the leafnodes?
diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..448ae4636 --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,310 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL. +module VarNamedTuples + +using AbstractPPL +using BangBang +using Accessors +using DynamicPPL: _compose_no_identity + +export VarNamedTuple + +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + +_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::IndexLens{T}) where {T} + return any(x <: UnitRange || x <: Colon for x in T.parameters) +end + +struct VarNamedTuple{T<:Function,Names,Values} + data::NamedTuple{Names,Values} + make_leaf::T +end + +struct IndexDict{T<:Function,Keys,Values} + data::Dict{Keys,Values} + make_leaf::T +end + +struct PartialArray{T<:Function,ElType,numdims} + data::Array{ElType,numdims} + mask::Array{Bool,numdims} + make_leaf::T +end + +function PartialArray(eltype, num_dims, make_leaf) + dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + data = Array{eltype,num_dims}(undef, dims) + mask = fill(false, dims) + return PartialArray(data, mask, make_leaf) +end + +_length_needed(i::Integer) = i +_length_needed(r::UnitRange) = last(r) +_length_needed(::Colon) = 0 + +"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size +we choose it to be. This size will be the smallest possible power of +PARTIAL_ARRAY_DIM_GROWTH_FACTOR. Growing PartialArrays in big jumps like this helps reduce +data copying, as resizes aren't needed as often. +""" +function _partial_array_dim_size(min_dim) + factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR + return factor^(Int(ceil(log(factor, min_dim)))) +end + +function _resize_partialarray(iarr::PartialArray, inds) + min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + new_sizes = map(_partial_array_dim_size, min_sizes) + # Generic multidimensional Arrays can not be resized, so we need to make a new one. + # See https://github.com/JuliaLang/julia/issues/37900 + new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_mask = fill(false, new_sizes) + # Note that we have to use CartesianIndices instead of eachindex, because the latter + # may use a linear index that does not match between the old and the new arrays. + for i in CartesianIndices(iarr.data) + mask_val = iarr.mask[i] + @inbounds new_mask[i] = mask_val + if mask_val + @inbounds new_data[i] = iarr.data[i] + end + end + return PartialArray(new_data, new_mask, iarr.make_leaf) +end + +# The below implements the same functionality as above, but more performantly for 1D arrays. +function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} + # Resize arrays to accommodate new indices. + old_size = size(iarr.data, 1) + min_size = max(old_size, _length_needed(ind)) + new_size = _partial_array_dim_size(min_size) + resize!(iarr.data, new_size) + resize!(iarr.mask, new_size) + @inbounds iarr.mask[(old_size + 1):new_size] .= false + return iarr +end + +function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) + if _has_colon(optic) + # TODO(mhauru) This could be implemented by getting size information from `value`. + # However, the corresponding getindex is more fundamentally ill-defined. 
+ throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + iarr = if checkbounds(Bool, iarr.mask, inds...) + iarr + else + _resize_partialarray(iarr, inds) + end + new_data = setindex!!(iarr.data, value, inds...) + if _is_multiindex(optic) + iarr.mask[inds...] .= true + else + iarr.mask[inds...] = true + end + return PartialArray(new_data, iarr.mask, iarr.make_leaf) +end + +function Base.getindex(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + if !haskey(iarr, optic) + throw(BoundsError(iarr, inds)) + end + return getindex(iarr.data, inds...) +end + +function Base.haskey(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + return checkbounds(Bool, iarr.mask, inds...) && + all(@inbounds(getindex(iarr.mask, inds...))) +end + +function make_leaf_array(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +end +make_leaf_array(value, ::typeof(identity)) = value +function make_leaf_array(value, optic::ComposedFunction) + sub = make_leaf_array(value, optic.outer) + return make_leaf_array(sub, optic.inner) +end + +function make_leaf_array(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic) ? 
eltype(value) : typeof(value) + iarr = PartialArray(et, num_inds, make_leaf_array) + return setindex!!(iarr, value, optic) +end + +function make_leaf_dict(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) +end +make_leaf_dict(value, ::typeof(identity)) = value +function make_leaf_dict(value, optic::ComposedFunction) + sub = make_leaf_dict(value, optic.outer) + return make_leaf_dict(sub, optic.inner) +end +function make_leaf_dict(value, optic::IndexLens) + return IndexDict(Dict(optic.indices => value), make_leaf_dict) +end + +VarNamedTuple() = VarNamedTuple((;), make_leaf_array) + +function Base.show(io::IO, vnt::VarNamedTuple) + print(io, "(") + for (i, (name, value)) in enumerate(pairs(vnt.data)) + if i > 1 + print(io, ", ") + end + print(io, name, " -> ") + print(io, value) + end + return print(io, ")") +end + +function Base.show(io::IO, id::IndexDict) + return print(io, id.data) +end + +Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] + +function varname_to_lens(name::VarName{S}) where {S} + return _compose_no_identity(getoptic(name), PropertyLens{S}()) +end + +function Base.getindex(vnt::VarNamedTuple, name::VarName) + return getindex(vnt, varname_to_lens(name)) +end +function Base.getindex( + x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction +) + subdata = getindex(x, optic.inner) + return getindex(subdata, optic.outer) +end +function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return getindex(vnt.data, S) +end +function Base.getindex(id::IndexDict, optic::IndexLens) + return getindex(id.data, optic.indices) +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName) + return haskey(vnt, varname_to_lens(name)) +end + +Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true + +function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) + return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) +end + +Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) +Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) +Base.haskey(::VarNamedTuple, ::IndexLens) = false +Base.haskey(::IndexDict, ::PropertyLens) = false + +# TODO(mhauru) This is type piracy. +Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) + +# TODO(mhauru) This is type piracy. +function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) + return BangBang.setindex!!(arr, value, optic.indices...) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) + return BangBang.setindex!!(vnt, value, varname_to_lens(name)) +end + +function BangBang.setindex!!( + vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction +) + sub = if haskey(vnt, optic.inner) + BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) + else + vnt.make_leaf(value, optic.outer) + end + return BangBang.setindex!!(vnt, sub, optic.inner) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} + # I would like this to just read + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the + # below? 
+    return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf)
+end
+
+function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens)
+    return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf)
+end
+
+function apply(func, vnt::VarNamedTuple, name::VarName)
+    # Look up by the VarName's leading symbol.
+    if !haskey(vnt.data, AbstractPPL.getsym(name))
+        throw(KeyError(repr(name)))
+    end
+    subdata = getindex(vnt, name)
+    new_subdata = func(subdata)
+    return BangBang.setindex!!(vnt, new_subdata, name)
+end
+
+function Base.map(func, vnt::VarNamedTuple)
+    new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data)))
+    return VarNamedTuple(new_data, vnt.make_leaf)
+end
+
+function Base.keys(vnt::VarNamedTuple)
+    result = ()
+    for sym in keys(vnt.data)
+        subdata = vnt.data[sym]
+        if subdata isa VarNamedTuple
+            subkeys = keys(subdata)
+            result = (
+                (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result...
+            )
+        else
+            result = (VarName{sym}(), result...)
+        end
+    end
+    return result
+end
+
+function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic}
+    if !haskey(vnt.data, S)
+        return false
+    end
+    subdata = vnt.data[S]
+    return if Optic === typeof(identity)
+        true
+    elseif Optic <: IndexLens
+        try
+            AbstractPPL.getoptic(name)(subdata)
+            true
+        catch e
+            if e isa BoundsError || e isa KeyError
+                false
+            else
+                rethrow(e)
+            end
+        end
+    else
+        haskey(subdata, AbstractPPL.unprefix(name, VarName{S}()))
+    end
+end
+
+end
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
new file mode 100644
index 000000000..85b824ffc
--- /dev/null
+++ b/test/varnamedtuple.jl
@@ -0,0 +1,89 @@
+module VarNamedTupleTests
+
+using Test: @inferred, @test, @test_throws, @testset
+using DynamicPPL: @varname, VarNamedTuple
+using BangBang: setindex!!
+
+@testset "Basic sets and gets" begin
+    vnt = VarNamedTuple()
+    vnt = @inferred(setindex!!(vnt, 32.0, @varname(a)))
+    @test @inferred(getindex(vnt, @varname(a))) == 32.0
+
+    vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b)))
+    @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3]
+    @test @inferred(getindex(vnt, @varname(b[2]))) == 2
+
+    vnt = @inferred(setindex!!(vnt, 64.0, @varname(a)))
+    @test @inferred(getindex(vnt, @varname(a))) == 64.0
+    @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3]
+
+    vnt = @inferred(setindex!!(vnt, 15, @varname(b[2])))
+    @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3]
+    @test @inferred(getindex(vnt, @varname(b[2]))) == 15
+
+    vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y)))
+    @test @inferred(getindex(vnt, @varname(c.x.y))) == [10]
+
+    vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1])))
+    @test @inferred(getindex(vnt, @varname(c.x.y))) == [11]
+    @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11
+
+    vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4])))
+    @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0
+
+    vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4])))
+    @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0
+
+    # These can't be @inferred because `d` now has an abstract element type. Note that this
+    # does not ruin type stability for other varnames that don't involve `d`.
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) +end + +end From 0c7825bd9b80494459393ea6b7349885d8c2e29c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 11:58:22 +0000 Subject: [PATCH 011/148] Add comparisons and merge --- src/varnamedtuple.jl | 179 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 120 ++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 21 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 448ae4636..006e8f0d5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -4,16 +4,18 @@ module VarNamedTuples using AbstractPPL using BangBang using Accessors -using DynamicPPL: _compose_no_identity +using ..DynamicPPL: _compose_no_identity export VarNamedTuple """The factor by which we increase the dimensions of PartialArrays when resizing them.""" const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 -_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) +const INDEX_TYPES = Union{Integer,UnitRange,Colon} -function _is_multiindex(::IndexLens{T}) where {T} +_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::T) where {T<:Tuple} return any(x <: UnitRange || x <: Colon for x in T.parameters) end @@ -22,6 +24,12 @@ struct VarNamedTuple{T<:Function,Names,Values} make_leaf::T end +# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for +# PartialArrays. 
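+# (As a rule of thumb, `hash` would only need to agree with the `==` below if these
+# types were used as `Dict` keys or stored in `Set`s; nothing in this module does
+# that yet, so equality alone suffices for now.)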
+function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data +end + struct IndexDict{T<:Function,Keys,Values} data::Dict{Keys,Values} make_leaf::T @@ -33,13 +41,44 @@ struct PartialArray{T<:Function,ElType,numdims} make_leaf::T end -function PartialArray(eltype, num_dims, make_leaf) +function PartialArray(eltype, num_dims, make_leaf=make_leaf_array) dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) return PartialArray(data, mask, make_leaf) end +Base.ndims(iarr::PartialArray) = ndims(iarr.data) + +# We deliberately don't define Base.size for PartialArray, because it is ill-defined. +# The size of the .data field is an implementation detail. +_internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) + +function Base.copy(pa::PartialArray) + return PartialArray(copy(pa.data), copy(pa.mask), pa.make_leaf) +end + +function Base.:(==)(pa1::PartialArray, pa2::PartialArray) + if (pa1.make_leaf !== pa2.make_leaf) || (ndims(pa1) != ndims(pa2)) + return false + end + size1 = _internal_size(pa1) + size2 = _internal_size(pa2) + # TODO(mhauru) This could be optimised, but not sure it's worth it. + merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + for i in CartesianIndices(merge_size) + m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false + m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false + if m1 != m2 + return false + end + if m1 && (pa1.data[i] != pa2.data[i]) + return false + end + end + return true +end + _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 @@ -55,11 +94,13 @@ function _partial_array_dim_size(min_dim) end function _resize_partialarray(iarr::PartialArray, inds) - min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + min_sizes = ntuple( + i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds) + ) new_sizes = map(_partial_array_dim_size, min_sizes) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_sizes) new_mask = fill(false, new_sizes) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. @@ -76,7 +117,7 @@ end # The below implements the same functionality as above, but more performantly for 1D arrays. function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} # Resize arrays to accommodate new indices. - old_size = size(iarr.data, 1) + old_size = _internal_size(iarr, 1) min_size = max(old_size, _length_needed(ind)) new_size = _partial_array_dim_size(min_size) resize!(iarr.data, new_size) @@ -85,14 +126,19 @@ function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,E return iarr end -function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) - if _has_colon(optic) +function BangBang.setindex!!(pa::PartialArray, value, optic::IndexLens) + return BangBang.setindex!!(pa, value, optic.indices...) +end +Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indices...) 
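+# (A sketch of what the `IndexLens` forwarding above and below buys us, assuming a
+# 2-dimensional `pa::PartialArray` with entry (2, 3) set:
+#     pa[IndexLens((2, 3))] == pa[2, 3]
+#     haskey(pa, IndexLens((2, 3))) == haskey(pa, (2, 3))
+# so the `Vararg` implementations further down do all the real work.)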
+Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) + +function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) + if _has_colon(inds) # TODO(mhauru) This could be implemented by getting size information from `value`. # However, the corresponding getindex is more fundamentally ill-defined. throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices - if length(inds) != ndims(iarr.data) + if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end iarr = if checkbounds(Bool, iarr.mask, inds...) @@ -101,7 +147,7 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) _resize_partialarray(iarr, inds) end new_data = setindex!!(iarr.data, value, inds...) - if _is_multiindex(optic) + if _is_multiindex(inds) iarr.mask[inds...] .= true else iarr.mask[inds...] = true @@ -109,29 +155,105 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) return PartialArray(new_data, iarr.mask, iarr.make_leaf) end -function Base.getindex(iarr::PartialArray, optic::IndexLens) - if _has_colon(optic) +function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) + if _has_colon(inds) throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices - if length(inds) != ndims(iarr.data) + if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end - if !haskey(iarr, optic) + if !haskey(iarr, inds) throw(BoundsError(iarr, inds)) end return getindex(iarr.data, inds...) end -function Base.haskey(iarr::PartialArray, optic::IndexLens) - if _has_colon(optic) +function Base.haskey(iarr::PartialArray, inds) + if _has_colon(inds) throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices return checkbounds(Bool, iarr.mask, inds...) && all(@inbounds(getindex(iarr.mask, inds...))) end +Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) +_merge_recursive(_, x2) = x2 + +function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex) + m1 = x1.mask[ind] + m2 = x2.mask[ind] + return if m1 && m2 + _merge_recursive(x1.data[ind], x2.data[ind]) + elseif m2 + x2.data[ind] + else + x1.data[ind] + end +end + +# TODO(mhauru) Would this benefit from a specialised method for 1D PartialArrays? +function _merge_recursive(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + throw( + ArgumentError("Cannot merge PartialArrays with different number of dimensions") + ) + end + if pa1.make_leaf !== pa2.make_leaf + throw( + ArgumentError("Cannot merge PartialArrays with different make_leaf functions") + ) + end + num_dims = ndims(pa1) + merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) + result = if merge_size == _internal_size(pa2) + # Either pa2 is strictly bigger than pa1, or they are equal in size. + result = copy(pa2) + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + else + if merge_size == _internal_size(pa1) + # pa1 is bigger than pa2 + result = copy(pa1) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result = setindex!!( + result, _merge_element_recursive(result, pa2, i), Tuple(i)... + ) + end + end + result + else + # Neither is strictly bigger than the other. 
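+            # Allocate fresh storage large enough for both, copy pa2's entries
+            # in first, then overlay pa1's; `_merge_element_recursive` still
+            # prefers pa2 wherever both arrays have an entry, matching
+            # `Base.merge`'s rightmost-wins convention. Note that `eltype` is
+            # not specialised for `PartialArray`, so the `promote_type` below
+            # currently sees `Any` for both arguments.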
+ et = promote_type(eltype(pa1), eltype(pa2)) + new_data = Array{et,num_dims}(undef, merge_size) + new_mask = fill(false, merge_size) + result = PartialArray(new_data, new_mask, pa2.make_leaf) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result.mask[i] = true + result.data[i] = pa2.data[i] + end + end + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + end + end + return result +end + function make_leaf_array(value, ::PropertyLens{S}) where {S} return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) end @@ -146,7 +268,7 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T} num_inds = length(inds) # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(optic) ? eltype(value) : typeof(value) + et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) iarr = PartialArray(et, num_inds, make_leaf_array) return setindex!!(iarr, value, optic) end @@ -307,4 +429,19 @@ function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} end end +# TODO(mhauru) Check the performance of this, and make it into a generated function if +# necessary. +function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + result_data = vnt1.data + for k in keys(vnt2.data) + val = if haskey(result_data, k) + _merge_recursive(result_data[k], vnt2.data[k]) + else + vnt2.data[k] + end + Accessors.@reset result_data[k] = val + end + return VarNamedTuple(result_data, vnt2.make_leaf) +end + end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 85b824ffc..f9864e7be 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -86,4 +86,124 @@ using BangBang: setindex!! @test !haskey(vnt, @varname(m[1])) end +@testset "equality" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + @test vnt1 != vnt2 + + vnt2 = setindex!!(vnt2, 1.0, @varname(a)) + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + + # Try with index lenses too + vnt1 = setindex!!(vnt1, 2, @varname(c[2])) + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, 3, @varname(c[2])) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + + vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) + vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) + @test vnt1 != vnt2 +end + +@testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
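+    # (Concretely, the line below would then read
+    #     @test @inferred(merge(vnt1, vnt2)) == expected_merge
+    # so that the test would also fail if inference regresses.)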
+ @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + @test merge(vnt1, vnt2) == expected_merge + + # PartialArrays with different sizes. 
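+    # (Indices as large as a[1025] force one operand's internal storage to be
+    # strictly larger than the other's; together with the 2D case further down,
+    # this exercises all three size branches of `_merge_recursive`.)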
+ vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[1025])) + vnt2 = setindex!!(vnt2, 2, @varname(a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(a[2])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) + @test merge(vnt2, vnt1) == expected_merge_21 + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) + @test merge(vnt2, vnt1) == expected_merge_21 +end + end From 15d5a8a97795de35390706e858cf60a48cb17b76 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 12:09:39 +0000 Subject: [PATCH 012/148] Start using VNT in FastLDF --- src/DynamicPPL.jl | 2 ++ src/contexts/init.jl | 16 +++------ src/fasteval.jl | 81 +++++++++++++------------------------------ test/fasteval.jl | 8 ++--- test/varnamedtuple.jl | 12 +++---- 5 files changed, 39 insertions(+), 80 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..5f32a8b66 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -178,6 +178,8 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuple include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..a0ad92fe3 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -215,8 +215,7 @@ end """ VectorWithRanges( - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, vect::AbstractVector{<:Real}, ) @@ -231,20 +230,13 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
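+
+A sketch of the lookup path (the `ranges_vnt` and `vect` names here are
+illustrative, not part of the API):
+
+    vr = VectorWithRanges(ranges_vnt, vect)
+    rl = _get_range_and_linked(vr, @varname(x))  # -> RangeAndLinked
+    vect[rl.range]  # the internal (possibly linked) values for `x`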
""" -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} - # This NamedTuple stores the ranges for identity VarNames - iden_varname_ranges::N - # This Dict stores the ranges for all other VarNames - varname_ranges::Dict{VarName,RangeAndLinked} +struct VectorWithRanges{VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} + # Ranges for all VarNames + varname_ranges::VNT # The full parameter vector which we index into to get variable values vect::T end -function _get_range_and_linked( - vr::VectorWithRanges, ::VarName{sym,typeof(identity)} -) where {sym} - return vr.iden_varname_ranges[sym] -end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) return vr.varname_ranges[vn] end diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..b82180dca 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -13,6 +13,7 @@ using DynamicPPL: RangeAndLinked, VectorWithRanges, Metadata, + VarNamedTuple, VarNamedVector, default_accumulators, float_type_with_fallback, @@ -140,14 +141,13 @@ struct FastLDF{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, - N<:NamedTuple, + VNT<:VarNamedTuple, ADP<:Union{Nothing,DI.GradientPrep}, } model::M adtype::AD _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} + _varname_ranges::VNT _adprep::ADP _dim::Int @@ -159,7 +159,7 @@ struct FastLDF{ ) # Figure out which variable corresponds to which index, and # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + all_ranges = get_ranges_and_linked(varinfo) x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -169,19 +169,17 @@ struct FastLDF{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, + FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x ) end return new{ typeof(model), typeof(adtype), typeof(getlogdensity), - typeof(all_iden_ranges), + typeof(all_ranges), typeof(prep), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim + model, adtype, getlogdensity, all_ranges, prep, dim ) end end @@ -206,18 +204,15 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct FastLogDensityAt{M<:Model,F<:Function,VNT<:VarNamedTuple} model::M getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} + varname_ranges::VNT end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = InitContext( Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), + InitFromParams(VectorWithRanges(f.varname_ranges, params), nothing), ) model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) @@ -242,20 +237,14 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - )( - params - ) + return FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges)(params) end function LogDensityProblems.logdensity_and_gradient( fldf::FastLDF, params::AbstractVector{<:Real} ) return 
DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - ), + FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges), fldf._adprep, fldf.adtype, params, @@ -291,62 +280,42 @@ end Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter representation, along with whether each variable is linked or unlinked. -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +This function returns a VarNamedTuple mapping all VarNames to their corresponding +`RangeAndLinked`. """ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) + this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_ranges = merge(all_ranges, this_md_others) end - return all_iden_ranges, all_ranges + return all_ranges end function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others + all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_ranges end function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset end function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset end diff --git a/test/fasteval.jl b/test/fasteval.jl index db2333711..2ad50ed26 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -36,17 +36,13 @@ end else unlinked_vi end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using 
the ranges
            # directly
-            range_with_linked = if AbstractPPL.getoptic(vn) === identity
-                nt_ranges[AbstractPPL.getsym(vn)]
-            else
-                dict_ranges[vn]
-            end
+            range_with_linked = ranges[vn]
             @test params[range_with_linked.range] == DynamicPPL.getindex_internal(vi, vn)

             # Check that the link status is correct
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index f9864e7be..99f528175 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -180,11 +180,11 @@ end
     vnt1 = VarNamedTuple()
     vnt2 = VarNamedTuple()
     vnt1 = setindex!!(vnt1, 1, @varname(a[1]))
-    vnt1 = setindex!!(vnt1, 1, @varname(a[1025]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[257]))
     vnt2 = setindex!!(vnt2, 2, @varname(a[1]))
     vnt2 = setindex!!(vnt2, 2, @varname(a[2]))
     expected_merge_12 = VarNamedTuple()
-    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025]))
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257]))
     expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1]))
     expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2]))
     @test merge(vnt1, vnt2) == expected_merge_12
@@ -194,13 +194,13 @@ end
     vnt1 = VarNamedTuple()
     vnt2 = VarNamedTuple()
     vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1]))
-    vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1]))
     vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1]))
-    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025]))
+    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257]))
     expected_merge_12 = VarNamedTuple()
     expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1]))
-    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1]))
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1]))
-    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025]))
+    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257]))
     @test merge(vnt1, vnt2) == expected_merge_12
     expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1]))
     @test merge(vnt2, vnt1) == expected_merge_21
 end

From 871eb9fd1216f392460462d4c84d8a38ca89da05 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 20 Nov 2025 12:41:55 +0000
Subject: [PATCH 013/148] Move _compose_no_identity to utils.jl

---
 src/utils.jl          | 16 ++++++++++++++++
 src/varnamedvector.jl | 16 ----------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index 75fb805dc..fe2879182 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -949,3 +949,19 @@ end
 Return `typeof(x)` stripped of its type parameters.
 """
 basetypeof(x::T) where {T} = Base.typename(T).wrapper
+
+# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
+# ReshapeTransforms are composed with each other or with an UnwrapSingeltonTransform, only
+# the latter one would be kept.
+"""
+    _compose_no_identity(f, g)
+
+Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.
+
+This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type
+conflicts.
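+
+For example, `_compose_no_identity(sin, identity) === sin` and
+`_compose_no_identity(identity, identity) === identity`, whereas plain
+`sin ∘ identity` would build a `ComposedFunction` wrapper.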
+""" +_compose_no_identity(f, g) = f ∘ g +_compose_no_identity(::typeof(identity), g) = g +_compose_no_identity(f, ::typeof(identity)) = f +_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..e5d2f2c2e 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1355,22 +1355,6 @@ function nextrange(vnv::VarNamedVector, x) return (offset + 1):(offset + length(x)) end -# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if -# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only -# the latter one would be kept. -""" - _compose_no_identity(f, g) - -Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. - -This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type -conflicts. -""" -_compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity - """ shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) From 4a1156038eb673dc9567d8a3a4d008455ec83908 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 22 Nov 2025 00:26:16 +0000 Subject: [PATCH 014/148] Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129) * Implement `ParamsWithStats` for `FastLDF` * Add comments * Implement `bundle_samples` for ParamsWithStats -> MCMCChains * Remove redundant comment * don't need Statistics? --- ext/DynamicPPLMCMCChainsExt.jl | 37 ++++++++++++++++ src/DynamicPPL.jl | 2 +- src/chains.jl | 57 ++++++++++++++++++++++++ src/fasteval.jl | 81 ++++++++++++++++++++++++---------- test/chains.jl | 28 +++++++++++- 5 files changed, 180 insertions(+), 25 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..e74f0b8a9 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -140,6 +140,43 @@ function AbstractMCMC.to_samples( end end +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DynamicPPL.Model, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + save_state=false, + stats=missing, + sort_chain=false, + discard_initial=0, + thinning=1, + kwargs..., +) + bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1)) + + # Add additional MCMC-specific info + info = bare_chain.info + if save_state + info = merge(info, (model=model, sampler=spl, samplerstate=state)) + end + if !ismissing(stats) + info = merge(info, (start_time=stats.start, stop_time=stats.stop)) + end + + # Reconstruct the chain with the extra information + # Yeah, this is quite ugly. Blame MCMCChains. + chain = MCMCChains.Chains( + bare_chain.value.data, + names(bare_chain), + bare_chain.name_map; + info=info, + start=discard_initial + 1, + thin=thinning, + ) + return sort_chain ? 
sort(chain) : chain +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..6d3900e91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -202,6 +202,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -209,7 +210,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..892423822 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -133,3 +133,60 @@ function ParamsWithStats( end return ParamsWithStats(params, stats) end + +""" + ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided +`param_vector`. + +This method is intended to replace the old method of obtaining parameters and statistics +via `unflatten` plus re-evaluation. It is faster for two reasons: + +1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as + otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent + MCMC iterations). +2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`. +""" +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + strategy = InitFromParams( + VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + nothing, + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + _, vi = DynamicPPL.Experimental.fast_evaluate!!( + ldf.model, strategy, AccumulatorTuple(accs) + ) + params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + lp=DynamicPPL.getlogjoint(vi), + ), + ) + end + return ParamsWithStats(params, stats) +end diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..722760fa1 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -3,6 +3,7 @@ using DynamicPPL: AccumulatorTuple, InitContext, InitFromParams, + AbstractInitStrategy, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI using Random: Random +""" + DynamicPPL.Experimental.fast_evaluate!!( + [rng::Random.AbstractRNG,] + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, params::AbstractVector{<:Real} + ) + +Evaluate a model using parameters obtained via `strategy`, and only computing the results in +the provided accumulators. + +It is assumed that the accumulators passed in have been initialised to appropriate values, +as this function will not reset them. The default constructors for each accumulator will do +this for you correctly. 
+ +Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` +argument may be mutated (depending on how the accumulators are implemented); hence the `!!` +in the function name. +""" +@inline function fast_evaluate!!( + # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads + # to extra allocations (even for trivial models) and much slower runtime. + rng::Random.AbstractRNG, + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, +) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + return DynamicPPL._evaluate!!(model, vi) +end +@inline function fast_evaluate!!( + model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple +) + # This `@inline` is also mandatory for performance + return fast_evaluate!!(Random.default_rng(), model, strategy, accs) +end + """ FastLDF( model::Model, @@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = InitContext( - Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - accs = map( - acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), - accs, - ) - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - _, vi = DynamicPPL._evaluate!!(model, vi) + _, vi = fast_evaluate!!(f.model, strategy, accs) return f.getlogdensity(vi) end diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..43b877d62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,30 @@ using Test end end +@testset "ParamsWithStats from FastLDF" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + params = [x for x in vi[:]] + + # Get the ParamsWithStats using FastLDF + fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) + ps = ParamsWithStats(params, fldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module From 766f6635903c401a79d3c2427dc60225f0053dad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 11:41:51 +0000 Subject: [PATCH 015/148] Make FastLDF the default (#1139) * Make FastLDF the default * Add miscellaneous LogDensityProblems tests * Use `init!!` instead of `fast_evaluate!!` * Rename files, rebalance tests --- HISTORY.md | 23 +- docs/src/api.md | 8 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 11 +- src/DynamicPPL.jl | 4 + src/chains.jl | 8 +- src/experimental.jl | 2 - src/fasteval.jl | 387 --------------- src/logdensityfunction.jl | 579 +++++++++++------------ src/model.jl | 54 ++- test/ad.jl | 137 ------ test/chains.jl | 8 +- test/fasteval.jl | 233 --------- test/logdensityfunction.jl | 263 ++++++++-- test/runtests.jl | 7 +- 14 files changed, 584 insertions(+), 1140 deletions(-) delete mode 100644 src/fasteval.jl delete mode 100644 test/ad.jl delete mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 0f0102ce4..91306c219 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,17 @@ ### Breaking changes +#### Fast Log Density Functions + +This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + +As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. +In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. +If you were previously relying on this behaviour, you will need to store a VarInfo separately. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. 
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). -### Other changes - -#### FastLDF - -Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. -Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. - -Please note that `FastLDF` is currently considered internal and its API may change without warning. -We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. - -For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. - ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index e81f18dc7..adb476db5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished using [`init!!`](@ref) on: + +```@docs +OnlyAccsVarInfo +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -510,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo. It is really a thin wrapper around using `evaluate!!` with an `InitContext`. ```@docs -DynamicPPL.init!! +init!! ``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... 
) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6d3900e91..a885f6a96 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,8 +92,12 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, + # evaluation + evaluate!!, + init!!, # LogDensityFunction LogDensityFunction, + OnlyAccsVarInfo, # Leaf contexts AbstractContext, contextualize, diff --git a/src/chains.jl b/src/chains.jl index 892423822..f176b8e68 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -137,7 +137,7 @@ end """ ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -174,9 +174,7 @@ function ParamsWithStats( else (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) end - _, vi = DynamicPPL.Experimental.fast_evaluate!!( - ldf.model, strategy, AccumulatorTuple(accs) - ) + _, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy) params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values if include_log_probs stats = merge( diff --git a/src/experimental.jl b/src/experimental.jl index c644c09b2..8c82dca68 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,8 +2,6 @@ module Experimental using DynamicPPL: DynamicPPL -include("fasteval.jl") - # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) diff --git a/src/fasteval.jl b/src/fasteval.jl deleted file mode 100644 index 722760fa1..000000000 --- a/src/fasteval.jl +++ /dev/null @@ -1,387 +0,0 @@ -using DynamicPPL: - AbstractVarInfo, - AccumulatorTuple, - InitContext, - InitFromParams, - AbstractInitStrategy, - LogJacobianAccumulator, - LogLikelihoodAccumulator, - LogPriorAccumulator, - Model, - ThreadSafeVarInfo, - VarInfo, - OnlyAccsVarInfo, - RangeAndLinked, - VectorWithRanges, - Metadata, - VarNamedVector, - default_accumulators, - float_type_with_fallback, - getlogjoint, - getlogjoint_internal, - getloglikelihood, - getlogprior, - getlogprior_internal -using ADTypes: ADTypes -using BangBang: BangBang -using AbstractPPL: AbstractPPL, VarName -using LogDensityProblems: LogDensityProblems -import DifferentiationInterface as DI -using Random: Random - -""" - DynamicPPL.Experimental.fast_evaluate!!( - [rng::Random.AbstractRNG,] - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, params::AbstractVector{<:Real} - ) - -Evaluate a model using parameters obtained via `strategy`, and only computing the results in -the provided accumulators. 
- -It is assumed that the accumulators passed in have been initialised to appropriate values, -as this function will not reset them. The default constructors for each accumulator will do -this for you correctly. - -Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` -argument may be mutated (depending on how the accumulators are implemented); hence the `!!` -in the function name. -""" -@inline function fast_evaluate!!( - # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads - # to extra allocations (even for trivial models) and much slower runtime. - rng::Random.AbstractRNG, - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, -) - ctx = InitContext(rng, strategy) - model = DynamicPPL.setleafcontext(model, ctx) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - param_eltype = DynamicPPL.get_param_eltype(strategy) - accs = map(accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) - end - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - return DynamicPPL._evaluate!!(model, vi) -end -@inline function fast_evaluate!!( - model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple -) - # This `@inline` is also mandatory for performance - return fast_evaluate!!(Random.default_rng(), model, strategy, accs) -end - -""" - FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at that point. - -This information can be extracted using the LogDensityProblems.jl interface, specifically, -using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If -`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD -backend type, then `logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of - linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of - linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, - since transforms are only applied to random variables) - -!!! 
note - By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of - `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created - with a linked or unlinked VarInfo. This is done primarily to ease interoperability with - MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created for you. If -you provide a different function, you have to manually create a VarInfo and pass it as the -third argument. - -If the `adtype` keyword argument is provided, then this struct will also store the adtype -along with other information for efficient calculation of the gradient of the log density. -Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend -itself to have been loaded (e.g. with `import Backend`). - -## Fields - -Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: - -- `fastldf.model`: The original model from which this `FastLDF` was constructed. -- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD - type was provided. - -# Extended help - -Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a -given set of parameters: - -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. - -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. - -In general, both of these approaches work fine, but the fact that they modify the VarInfo's -metadata can often be quite wasteful. In particular, it is very common that the only outputs -we care about from model evaluation are those which are stored in accumulators, such as log -probability densities, or `ValuesAsInModel`. - -To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains -accumulators. It implements enough of the `AbstractVarInfo` interface to not error during -model evaluation. - -Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with -it, it is mandatory that parameters are provided from outside the VarInfo, namely via -`InitContext`. - -The main problem that we face is that it is not possible to directly implement -`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. -In particular, it is not clear: - - - which parts of the vector correspond to which random variables, and - - whether the variables are linked or unlinked. - -Traditionally, this problem has been solved by `unflatten`, because that function would -place values into the VarInfo's metadata alongside the information about ranges and linking. -That way, when we evaluate with `DefaultContext`, we can read this information out again. -However, we want to avoid using a metadata. Thus, here, we _extract this information from -the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we -store a mapping from VarNames to ranges in that vector, along with link status. - -For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all -other VarNames, this is stored in a Dict. The internal data structure used to represent this -could almost certainly be optimised further. See e.g. the discussion in -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
- -When evaluating the model, this allows us to combine the parameter vector together with those -ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read -parameter values from the vector. - -Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. -""" -struct FastLDF{ - M<:Model, - AD<:Union{ADTypes.AbstractADType,Nothing}, - F<:Function, - N<:NamedTuple, - ADP<:Union{Nothing,DI.GradientPrep}, -} - model::M - adtype::AD - _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} - _adprep::ADP - _dim::Int - - function FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - # Figure out which variable corresponds to which index, and - # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) - x = [val for val in varinfo[:]] - dim = length(x) - # Do AD prep if needed - prep = if adtype === nothing - nothing - else - # Make backend-specific tweaks to the adtype - adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, - ) - end - return new{ - typeof(model), - typeof(adtype), - typeof(getlogdensity), - typeof(all_iden_ranges), - typeof(prep), - }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim - ) - end -end - -################################### -# LogDensityProblems.jl interface # -################################### -""" - fast_ldf_accs(getlogdensity::Function) - -Determine which accumulators are needed for fast evaluation with the given -`getlogdensity` function. 
-""" -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) - return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) -end -function fast_ldf_accs(::typeof(getlogprior_internal)) - return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) -end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) - -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} - model::M - getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} -end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ) - accs = fast_ldf_accs(f.getlogdensity) - _, vi = fast_evaluate!!(f.model, strategy, accs) - return f.getlogdensity(vi) -end - -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - )( - params - ) -end - -function LogDensityProblems.logdensity_and_gradient( - fldf::FastLDF, params::AbstractVector{<:Real} -) - return DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - ), - fldf._adprep, - fldf.adtype, - params, - ) -end - -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} -) where {M} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} -) where {M} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.dimension(fldf::FastLDF) - return fldf._dim -end - -###################################################### -# Helper functions to extract ranges and link status # -###################################################### - -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. -""" - get_ranges_and_linked(varinfo::VarInfo) - -Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter -representation, along with whether each variable is linked or unlinked. - -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
-""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) - end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7c7438c9f..65eab448e 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,312 +1,263 @@ -using AbstractMCMC: AbstractModel +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + AbstractInitStrategy, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI +using Random: Random """ - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. - -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. 
-""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) A struct which contains a model, along with all the information necessary to: - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) !!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. 
If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: + +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. 
In particular, it is very common that the only outputs
+we care about from model evaluation are those which are stored in accumulators, such as log
+probability densities, or `ValuesAsInModel`.
+
+To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
+accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
+model evaluation.
+
+Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
+it, it is mandatory that parameters are provided from outside the VarInfo, namely via
+`InitContext`.
+
+The main problem that we face is that it is not possible to directly implement
+`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`.
+In particular, it is not clear:
+
+  - which parts of the vector correspond to which random variables, and
+  - whether the variables are linked or unlinked.
+
+Traditionally, this problem has been solved by `unflatten`, because that function would
+place values into the VarInfo's metadata alongside the information about ranges and linking.
+That way, when we evaluate with `DefaultContext`, we can read this information out again.
+However, we want to avoid using the metadata. Thus, here, we _extract this information from
+the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the
+LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with
+link status.
+
+For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
+other VarNames, this is stored in a Dict. The internal data structure used to represent this
+could almost certainly be optimised further. See e.g. the discussion in
+https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
+
+When evaluating the model, this allows us to combine the parameter vector together with those
+ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read
+parameter values from the vector.
+
+Note that this assumes that the ranges and link status are static throughout the lifetime of
+the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle
+models which have variable numbers of parameters, or models which may visit random variables
+in different orders depending on stochastic control flow. **Indeed, silent errors may occur
+with such models.** This is a general limitation of vectorised parameters: the original
+`unflatten` + `evaluate!!` approach also fails with such models.
 """
 struct LogDensityFunction{
-    M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
-} <: AbstractModel
-    "model used for evaluation"
+    M<:Model,
+    AD<:Union{ADTypes.AbstractADType,Nothing},
+    F<:Function,
+    N<:NamedTuple,
+    ADP<:Union{Nothing,DI.GradientPrep},
+}
    model::M
-    "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`."
-    getlogdensity::F
-    "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
-    varinfo::V
-    "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + _dim::Int function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - if adtype === nothing - prep = nothing + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) + # Do AD prep if needed + prep = if adtype === nothing + nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + DI.prepare_gradient( + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end +################################### +# LogDensityProblems.jl interface # +################################### """ - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) + fast_ldf_accs(getlogdensity::Function) -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. """ -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. 
Please specify a VarInfo explicitly. - """ - return error(msg) +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, +function (f::LogDensityAt)(params::AbstractVector{<:Real}) + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) + accs = fast_ldf_accs(f.getlogdensity) + _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) + return f.getlogdensity(vi) +end -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} ) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + )( + params + ) end -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V +function LogDensityProblems.logdensity_and_gradient( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ), + ldf._adprep, + ldf.adtype, + params, ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. 
-""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) end -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} +) where {M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils """ tweak_adtype( @@ -325,53 +276,77 @@ By default, this just returns the input unchanged. """ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. +###################################################### +# Helper functions to extract ranges and link status # +###################################################### -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. 
The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. """ -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false + get_ranges_and_linked(varinfo::VarInfo) -""" - getmodel(f) +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model +This function should return a tuple containing: +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. """ - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. 
-""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/src/model.jl b/src/model.jl index 2bcfe8f98..9029318b1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -881,30 +881,56 @@ end [init_strategy::AbstractInitStrategy=InitFromPrior()] ) -Evaluate the `model` and replace the values of the model's random variables -in the given `varinfo` with new values, using a specified initialisation strategy. -If the values in `varinfo` are not set, they will be added -using a specified initialisation strategy. +Evaluate the `model` and replace the values of the model's random variables in the given +`varinfo` with new values, using a specified initialisation strategy. If the values in +`varinfo` are not set, they will be added using a specified initialisation strategy. If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function init!!( +@inline function init!!( + # Note that this `@inline` is mandatory for performance, especially for + # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for + # trivial models) and much slower runtime. rng::Random.AbstractRNG, model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), + vi::AbstractVarInfo, + strategy::AbstractInitStrategy=InitFromPrior(), ) - new_model = setleafcontext(model, InitContext(rng, init_strategy)) - return evaluate!!(new_model, varinfo) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + return if Threads.nthreads() > 1 + # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that + # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` + # won't have been filled with parameters prior to `init!!` being called. + # + # Note that this eltype promotion is only needed for threadsafe evaluation. In an + # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a + # similar method. In other words, it should not be here, and it should not be inside + # `unflatten` either. The problem is performance. Shifting this code around can have + # massive, inexplicable, impacts on performance. This should be investigated + # properly. 
+ param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(vi.accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + vi = DynamicPPL.setaccs!!(vi, accs) + tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + else + return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + end end -function init!!( - model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), +@inline function init!!( + model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - return init!!(Random.default_rng(), model, varinfo, init_strategy) + # This `@inline` is also mandatory for performance + return init!!(Random.default_rng(), model, vi, strategy) end """ diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index 43b877d62..12a9ece71 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -66,7 +66,7 @@ using Test end end -@testset "ParamsWithStats from FastLDF" begin +@testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) @@ -77,9 +77,9 @@ end end params = [x for x in vi[:]] - # Get 
the ParamsWithStats using FastLDF - fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) - ps = ParamsWithStats(params, fldf) + # Get the ParamsWithStats using LogDensityFunction + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) + ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected @test length(ps.params) == length(keys(vi)) diff --git a/test/fasteval.jl b/test/fasteval.jl deleted file mode 100644 index db2333711..000000000 --- a/test/fasteval.jl +++ /dev/null @@ -1,233 +0,0 @@ -module DynamicPPLFastLDFTests - -using AbstractPPL: AbstractPPL -using Chairmarks -using DynamicPPL -using Distributions -using DistributionsAD: filldist -using ADTypes -using DynamicPPL.Experimental: FastLDF -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest -using LinearAlgebra: I -using Test -using LogDensityProblems: LogDensityProblems - -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -# Need to include this block here in case we run this test file standalone -@static if VERSION < v"1.12" - using Pkg - Pkg.add("Mooncake") - using Mooncake: Mooncake -end - -@testset "FastLDF: Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$varinfo_func" for varinfo_func in [ - DynamicPPL.untyped_varinfo, - DynamicPPL.typed_varinfo, - DynamicPPL.untyped_vector_varinfo, - DynamicPPL.typed_vector_varinfo, - ] - unlinked_vi = varinfo_func(m) - @testset "$islinked" for islinked in (false, true) - vi = if islinked - DynamicPPL.link!!(unlinked_vi, m) - else - unlinked_vi - end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = [x for x in vi[:]] - # Iterate over all variables - for vn in keys(vi) - # Check that `getindex_internal` returns the same thing as using the ranges - # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity - nt_ranges[AbstractPPL.getsym(vn)] - else - dict_ranges[vn] - end - @test params[range_with_linked.range] == - DynamicPPL.getindex_internal(vi, vn) - # Check that the link status is correct - @test range_with_linked.is_linked == islinked - end - - # Compare results of FastLDF vs ordinary LogDensityFunction. These tests - # can eventually go once we replace LogDensityFunction with FastLDF, but - # for now it helps to have this check! (Eventually we should just check - # against manually computed log-densities). - # - # TODO(penelopeysm): I think we need to add tests for some really - # pathological models here. - @testset "$getlogdensity" for getlogdensity in ( - DynamicPPL.getlogjoint_internal, - DynamicPPL.getlogjoint, - DynamicPPL.getloglikelihood, - DynamicPPL.getlogprior_internal, - DynamicPPL.getlogprior, - ) - ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) - fldf = FastLDF(m, getlogdensity, vi) - @test LogDensityProblems.logdensity(ldf, params) ≈ - LogDensityProblems.logdensity(fldf, params) - end - end - end - end - - @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end - end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) - end - end -end - -@testset "FastLDF: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) - end - end -end - -@testset "AD with FastLDF" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = @static if VERSION < v"1.12" - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) - linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = [p for p in linked_varinfo[:]] - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $adtype" - - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. 
- @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end - -end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index fbd868f71..06492d6e1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,49 +1,240 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model +module DynamicPPLLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake + +@testset "LogDensityFunction: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if 
islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end end end -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) +@testset "LogDensityFunction: interface" begin + # miscellaneous parts of the LogDensityProblems interface + @testset "dimensions" begin + @model function m1() + x ~ Normal() + y ~ Normal() + return nothing + end + model = m1() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 2 + + @model function m2() + x ~ Dirichlet(ones(4)) + y ~ Categorical(x) + return nothing end + model = m2() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 5 + linked_vi = DynamicPPL.link!!(VarInfo(model), model) + ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi) + @test LogDensityProblems.dimension(ldf) == 4 end @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @model f() = x ~ Normal() + model = f() + # No adtype ldf = DynamicPPL.LogDensityFunction(model) @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + # With adtype + ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{1}() end end + +@testset "LogDensityFunction: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(ldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with LogDensityFunction" begin + # Used as the ground truth that others are compared against. 
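+    # (ForwardDiff propagates dual numbers through the primal computation, so its
+    # gradients are exact up to floating-point rounding, which makes it a reasonable
+    # reference for the reverse-mode backends tested below.)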
+ ref_adtype = AutoForwardDiff() + + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 1474b426a..9649aebbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using ForwardDiff using LogDensityProblems using MacroTools using MCMCChains +using Mooncake using StableRNGs using ReverseDiff using Mooncake @@ -57,7 +58,6 @@ include("test_util.jl") include("simple_varinfo.jl") 
include("model.jl") include("distribution_wrappers.jl") - include("logdensityfunction.jl") include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -68,10 +68,11 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("chains.jl") - include("bijector.jl") end if GROUP == "All" || GROUP == "Group2" + include("bijector.jl") + include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") @@ -80,8 +81,6 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") - include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." From c1b935b5356ca98c040cc271e5ad8414d9fd2f61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:18:10 +0000 Subject: [PATCH 016/148] Minor refactor --- src/varnamedtuple.jl | 28 +-- test/varnamedtuple.jl | 402 +++++++++++++++++++++--------------------- 2 files changed, 216 insertions(+), 214 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 006e8f0d5..8d35b34a7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -93,15 +93,17 @@ function _partial_array_dim_size(min_dim) return factor^(Int(ceil(log(factor, min_dim)))) end +function _min_size(iarr::PartialArray, inds) + return ntuple(i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds)) +end + function _resize_partialarray(iarr::PartialArray, inds) - min_sizes = ntuple( - i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds) - ) - new_sizes = map(_partial_array_dim_size, min_sizes) + min_size = _min_size(iarr, inds) + new_size = map(_partial_array_dim_size, min_size) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_sizes) - new_mask = fill(false, new_sizes) + new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_size) + new_mask = fill(false, new_size) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. for i in CartesianIndices(iarr.data) @@ -133,14 +135,12 @@ Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indi Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) + if length(inds) != ndims(iarr) + throw(BoundsError(iarr, inds)) + end if _has_colon(inds) - # TODO(mhauru) This could be implemented by getting size information from `value`. - # However, the corresponding getindex is more fundamentally ill-defined. throw(ArgumentError("Indexing with colons is not supported")) end - if length(inds) != ndims(iarr) - throw(ArgumentError("Invalid index $(inds)")) - end iarr = if checkbounds(Bool, iarr.mask, inds...) 
iarr else @@ -156,12 +156,12 @@ function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES end function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end + if _has_colon(inds) + throw(ArgumentError("Indexing with colons is not supported")) + end if !haskey(iarr, inds) throw(BoundsError(iarr, inds)) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 99f528175..02ed3bca8 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -4,206 +4,208 @@ using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: @varname, VarNamedTuple using BangBang: setindex!! -@testset "Basic sets and gets" begin - vnt = VarNamedTuple() - vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) - @test @inferred(getindex(vnt, @varname(a))) == 32.0 - - vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) - @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] - @test @inferred(getindex(vnt, @varname(b[2]))) == 2 - - vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) - @test @inferred(getindex(vnt, @varname(a))) == 64.0 - @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] - - vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) - @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] - @test @inferred(getindex(vnt, @varname(b[2]))) == 15 - - vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) - @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] - - vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) - @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] - @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 - - vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) - @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 - - vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) - @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 - - # These can't be @inferred because `d` now has an abstract element type. Note that this - # does not ruin type stability for other varnames that don't involve `d`. - vnt = setindex!!(vnt, "a", @varname(d[5])) - @test getindex(vnt, @varname(d[5])) == "a" - - vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) - @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 - - vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) - @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 - - vec = fill(1.0, 4) - vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) - @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec - @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] - @test haskey(vnt, @varname(j[4])) - @test !haskey(vnt, @varname(j[5])) - @test_throws BoundsError getindex(vnt, @varname(j[5])) - - vec = fill(2.0, 4) - vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) - @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 - @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec - @test haskey(vnt, @varname(j[5])) - - arr = fill(2.0, (4, 2)) - vn = @varname(k.l[2:5, 3, 1:2, 2]) - vnt = @inferred(setindex!!(vnt, arr, vn)) - @test @inferred(getindex(vnt, vn)) == arr - # A subset of the elements set just now. - @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) - - # Not enough, or too many, indices. 
- @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) - @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) - - arr = fill(3.0, (3, 3)) - vn = @varname(k.l[1, 1:3, 1:3, 1]) - vnt = @inferred(setindex!!(vnt, arr, vn)) - @test @inferred(getindex(vnt, vn)) == arr - # A subset of the elements set just now. - @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) - # A subset of the elements set previously. - @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) - @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) - - vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) - vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) - @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] - @test !haskey(vnt, @varname(m[1])) -end - -@testset "equality" begin - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, 1.0, @varname(a)) - @test vnt1 != vnt2 - - vnt2 = setindex!!(vnt2, 1.0, @varname(a)) - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - - # Try with index lenses too - vnt1 = setindex!!(vnt1, 2, @varname(c[2])) - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, 3, @varname(c[2])) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - - vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) - vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) - @test vnt1 != vnt2 -end - -@testset "merge" begin - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - expected_merge = VarNamedTuple() - # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
- @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, 1.0, @varname(a)) - vnt2 = setindex!!(vnt2, 2.0, @varname(b)) - vnt1 = setindex!!(vnt1, 1, @varname(c)) - vnt2 = setindex!!(vnt2, 2, @varname(c)) - expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) - expected_merge = setindex!!(expected_merge, 2, @varname(c)) - expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - expected_merge = VarNamedTuple() - vnt1 = setindex!!(vnt1, [1], @varname(d.a)) - vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) - vnt1 = setindex!!(vnt1, [1], @varname(d.c)) - vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) - expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) - expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) - expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) - vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) - expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) - expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) - vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) - vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) - expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) - vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) - expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) - expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) - vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) - expected_merge = setindex!!( - expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) - ) - vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - @test merge(vnt1, vnt2) == expected_merge - - # PartialArrays with different sizes. 
- vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - vnt1 = setindex!!(vnt1, 1, @varname(a[1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[257])) - vnt2 = setindex!!(vnt2, 2, @varname(a[1])) - vnt2 = setindex!!(vnt2, 2, @varname(a[2])) - expected_merge_12 = VarNamedTuple() - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) - expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) - expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) - @test merge(vnt1, vnt2) == expected_merge_12 - expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) - @test merge(vnt2, vnt1) == expected_merge_21 - - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) - expected_merge_12 = VarNamedTuple() - expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) - expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) - @test merge(vnt1, vnt2) == expected_merge_12 - expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) - @test merge(vnt2, vnt1) == expected_merge_21 +@testset "VarNamedTuple" begin + @testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) + end + + @testset "equality" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + @test vnt1 != vnt2 + + vnt2 = setindex!!(vnt2, 1.0, @varname(a)) + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + + # Try with index lenses too + vnt1 = setindex!!(vnt1, 2, @varname(c[2])) + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, 3, @varname(c[2])) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + + vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) + vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) + @test vnt1 != vnt2 + end + + @testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
+ @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + @test merge(vnt1, vnt2) == expected_merge + + # PartialArrays with different sizes. 
+ vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257])) + vnt2 = setindex!!(vnt2, 2, @varname(a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(a[2])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) + @test merge(vnt2, vnt1) == expected_merge_21 + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) + @test merge(vnt2, vnt1) == expected_merge_21 + end end end From 262a6f98c05f2536ef2e9dcdf6910f332e292e36 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:21:42 +0000 Subject: [PATCH 017/148] Remove IndexDict --- src/varnamedtuple.jl | 36 ++---------------------------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 8d35b34a7..d07555db2 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -30,11 +30,6 @@ function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data end -struct IndexDict{T<:Function,Keys,Values} - data::Dict{Keys,Values} - make_leaf::T -end - struct PartialArray{T<:Function,ElType,numdims} data::Array{ElType,numdims} mask::Array{Bool,numdims} @@ -273,18 +268,6 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T} return setindex!!(iarr, value, optic) end -function make_leaf_dict(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) -end -make_leaf_dict(value, ::typeof(identity)) = value -function make_leaf_dict(value, optic::ComposedFunction) - sub = make_leaf_dict(value, optic.outer) - return make_leaf_dict(sub, optic.inner) -end -function make_leaf_dict(value, optic::IndexLens) - return IndexDict(Dict(optic.indices => value), make_leaf_dict) -end - VarNamedTuple() = VarNamedTuple((;), make_leaf_array) function Base.show(io::IO, vnt::VarNamedTuple) @@ -299,10 +282,6 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -function Base.show(io::IO, id::IndexDict) - return print(io, id.data) -end - Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] function varname_to_lens(name::VarName{S}) where {S} @@ -312,18 +291,13 @@ end function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end -function Base.getindex( - x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction -) +function Base.getindex(x::Union{VarNamedTuple,PartialArray}, optic::ComposedFunction) subdata = getindex(x, optic.inner) return getindex(subdata, optic.outer) end function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) 
where {S} return getindex(vnt.data, S) end -function Base.getindex(id::IndexDict, optic::IndexLens) - return getindex(id.data, optic.indices) -end function Base.haskey(vnt::VarNamedTuple, name::VarName) return haskey(vnt, varname_to_lens(name)) @@ -336,9 +310,7 @@ function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) end Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) Base.haskey(::VarNamedTuple, ::IndexLens) = false -Base.haskey(::IndexDict, ::PropertyLens) = false # TODO(mhauru) This is type piracy. Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) @@ -353,7 +325,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) end function BangBang.setindex!!( - vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction + vnt::Union{VarNamedTuple,PartialArray}, value, optic::ComposedFunction ) sub = if haskey(vnt, optic.inner) BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) @@ -371,10 +343,6 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) end -function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens) - return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf) -end - function apply(func, vnt::VarNamedTuple, name::VarName) if !haskey(vnt.data, name.name) throw(KeyError(repr(name))) From abea08782d93e1d4668c1838bce95d7c8b8c7483 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:27:17 +0000 Subject: [PATCH 018/148] Remove make_leaf as a field --- src/varnamedtuple.jl | 59 +++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d07555db2..0dfb9ec11 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -19,28 +19,24 @@ function _is_multiindex(::T) where {T<:Tuple} return any(x <: UnitRange || x <: Colon for x in T.parameters) end -struct VarNamedTuple{T<:Function,Names,Values} +struct VarNamedTuple{Names,Values} data::NamedTuple{Names,Values} - make_leaf::T end # TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for # PartialArrays. -function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data -end +Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data -struct PartialArray{T<:Function,ElType,numdims} +struct PartialArray{ElType,numdims} data::Array{ElType,numdims} mask::Array{Bool,numdims} - make_leaf::T end -function PartialArray(eltype, num_dims, make_leaf=make_leaf_array) +function PartialArray(eltype, num_dims) dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) - return PartialArray(data, mask, make_leaf) + return PartialArray(data, mask) end Base.ndims(iarr::PartialArray) = ndims(iarr.data) @@ -50,11 +46,11 @@ Base.ndims(iarr::PartialArray) = ndims(iarr.data) _internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) 
function Base.copy(pa::PartialArray) - return PartialArray(copy(pa.data), copy(pa.mask), pa.make_leaf) + return PartialArray(copy(pa.data), copy(pa.mask)) end function Base.:(==)(pa1::PartialArray, pa2::PartialArray) - if (pa1.make_leaf !== pa2.make_leaf) || (ndims(pa1) != ndims(pa2)) + if ndims(pa1) != ndims(pa2) return false end size1 = _internal_size(pa1) @@ -108,11 +104,11 @@ function _resize_partialarray(iarr::PartialArray, inds) @inbounds new_data[i] = iarr.data[i] end end - return PartialArray(new_data, new_mask, iarr.make_leaf) + return PartialArray(new_data, new_mask) end # The below implements the same functionality as above, but more performantly for 1D arrays. -function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} +function _resize_partialarray(iarr::PartialArray{Eltype,1}, (ind,)) where {Eltype} # Resize arrays to accommodate new indices. old_size = _internal_size(iarr, 1) min_size = max(old_size, _length_needed(ind)) @@ -147,7 +143,7 @@ function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES else iarr.mask[inds...] = true end - return PartialArray(new_data, iarr.mask, iarr.make_leaf) + return PartialArray(new_data, iarr.mask) end function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) @@ -194,11 +190,6 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) ArgumentError("Cannot merge PartialArrays with different number of dimensions") ) end - if pa1.make_leaf !== pa2.make_leaf - throw( - ArgumentError("Cannot merge PartialArrays with different make_leaf functions") - ) - end num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) result = if merge_size == _internal_size(pa2) @@ -229,7 +220,7 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) et = promote_type(eltype(pa1), eltype(pa2)) new_data = Array{et,num_dims}(undef, merge_size) new_mask = fill(false, merge_size) - result = PartialArray(new_data, new_mask, pa2.make_leaf) + result = PartialArray(new_data, new_mask) for i in CartesianIndices(pa2.data) @inbounds if pa2.mask[i] result.mask[i] = true @@ -249,26 +240,26 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end -function make_leaf_array(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +function make_leaf(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,))) end -make_leaf_array(value, ::typeof(identity)) = value -function make_leaf_array(value, optic::ComposedFunction) - sub = make_leaf_array(value, optic.outer) - return make_leaf_array(sub, optic.inner) +make_leaf(value, ::typeof(identity)) = value +function make_leaf(value, optic::ComposedFunction) + sub = make_leaf(value, optic.outer) + return make_leaf(sub, optic.inner) end -function make_leaf_array(value, optic::IndexLens{T}) where {T} +function make_leaf(value, optic::IndexLens{T}) where {T} inds = optic.indices num_inds = length(inds) # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? 
eltype(value) : typeof(value) - iarr = PartialArray(et, num_inds, make_leaf_array) + iarr = PartialArray(et, num_inds) return setindex!!(iarr, value, optic) end -VarNamedTuple() = VarNamedTuple((;), make_leaf_array) +VarNamedTuple() = VarNamedTuple((;)) function Base.show(io::IO, vnt::VarNamedTuple) print(io, "(") @@ -330,17 +321,17 @@ function BangBang.setindex!!( sub = if haskey(vnt, optic.inner) BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) else - vnt.make_leaf(value, optic.outer) + make_leaf(value, optic.outer) end return BangBang.setindex!!(vnt, sub, optic.inner) end function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} # I would like this to just read - # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the # below? - return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) + return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) end function apply(func, vnt::VarNamedTuple, name::VarName) @@ -354,7 +345,7 @@ end function Base.map(func, vnt::VarNamedTuple) new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) - return VarNamedTuple(new_data, vnt.make_leaf) + return VarNamedTuple(new_data) end function Base.keys(vnt::VarNamedTuple) @@ -409,7 +400,7 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) end Accessors.@reset result_data[k] = val end - return VarNamedTuple(result_data, vnt2.make_leaf) + return VarNamedTuple(result_data) end end From 5900f6906a0b3951baf5db3f419fd838a215a10a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 14:35:18 +0000 Subject: [PATCH 019/148] Document, refactor, and fix PartialArray --- src/varnamedtuple.jl | 369 ++++++++++++++++++++++++++++-------------- test/varnamedtuple.jl | 18 ++- 2 files changed, 267 insertions(+), 120 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0dfb9ec11..79a09f678 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,42 +8,139 @@ using ..DynamicPPL: _compose_no_identity export VarNamedTuple -"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" -const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 - -const INDEX_TYPES = Union{Integer,UnitRange,Colon} - +# We define our own getindex, setindex!!, and haskey functions to be able to override their +# behaviour for some types exported from elsewhere without type piracy. This is needed +# because +# 1. We want to index into things with lenses (from Accessors.jl) using getindex and +# setindex!!. +# 2. We want to use getindex, setindex!!, and haskey as the universal functions for getting, +# setting, checking. This includes e.g. checking whether an index is valid for an Array, +# which would normally be done with checkbounds. +_haskey(x, key) = Base.haskey(x, key) +_getindex(x, inds...) = Base.getindex(x, inds...) +_setindex!!(x, value, inds...) = BangBang.setindex!!(x, value, inds...) +_getindex(arr::AbstractArray, optic::IndexLens) = _getindex(arr, optic.indices...) +_haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) +function _setindex!!(arr::AbstractArray, value, optic::IndexLens) + return _setindex!!(arr, value, optic.indices...) +end +_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) + +# Some utilities for checking what sort of indices we are dealing with. 
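+# For example, the index tuple (1, 2:3) is a multi-index because it contains a range,
+# while (1, :) contains a colon.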
_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters)
-
function _is_multiindex(::T) where {T<:Tuple}
    return any(x <: UnitRange || x <: Colon for x in T.parameters)
end

"""
    _merge_recursive(x1, x2)

Recursively merge two values `x1` and `x2`.

Unlike `Base.merge`, this function is defined for all types, and by default returns the
second argument. It is overridden for `PartialArray` and `VarNamedTuple`, since they are
nested containers, and it calls itself recursively on all elements that are found in both
`x1` and `x2`.

In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]`
is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`,
unless no specific method is defined for the type of `x` and `y`, in which case
`_merge_recursive(x, y) === y`.
"""
_merge_recursive(_, x2) = x2

"""The factor by which we increase the dimensions of PartialArrays when resizing them."""
const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4

"""A convenience for defining method argument type bounds."""
const INDEX_TYPES = Union{Integer,UnitRange,Colon}

"""
    PartialArray{ElType,num_dims}

An array-like structure that may only have some of its elements defined.

A `PartialArray` is like a `Base.Array`, except not all of its elements are necessarily
defined. That is to say, one can create an empty `PartialArray` `arr` and e.g. set
`arr[3,2] = 5`, but asking for `arr[1,1]` may throw a `BoundsError` if `[1, 1]` has not been
explicitly set yet.

`PartialArray`s can be indexed with integer indices and ranges. Indexing is always 1-based.
Other types of indexing allowed by `Base.Array` are not supported. Some of these are
unsupported simply because we haven't seen a need and haven't bothered to implement them,
namely boolean indexing, linear indexing into multidimensional arrays, and indexing with
arrays. However, notably, indexing with colons (i.e. `:`) is not supported for more
fundamental reasons.

To understand this, note that a `PartialArray` has no well-defined size. For example, if one
creates an empty array and sets `arr[3,2]`, it is unclear whether that should be taken to
mean that the array has size `(3,2)`: it could be larger, and saying that the size is
`(3,2)` would also misleadingly suggest that all elements within `1:3,1:2` are set. This is
also why colon indexing is ill-defined: if one were to set e.g. `arr[2,:] = [1,2,3]`, we
would have no way of saying whether the right-hand side is of an acceptable size or not.

The fact that its size is ill-defined also means that `PartialArray` is not a subtype of
`AbstractArray`.

All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`,
`push!`, etc. are not defined. The element type of a `PartialArray` will change as needed
under `setindex!!` to accommodate the new values.

Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type
`ElType` and number of dimensions `num_dims`.
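
A minimal usage sketch (`setindex!!` here is `BangBang.setindex!!`; the values and indices
are only illustrative):

```julia
pa = PartialArray{Float64,1}()
pa = setindex!!(pa, 1.0, 16)  # grows the internal storage as needed
pa[16]                        # returns 1.0
pa[1]                         # throws a BoundsError, since this element was never set
```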
The internal implementation of a `PartialArray` consists of two arrays: one holding the
data and the other one being a boolean mask indicating which elements are defined. These
internal arrays may need resizing when new elements are set that have index ranges larger
than the current internal arrays. To avoid resizing too often, the internal arrays are
resized in exponentially increasing steps. This means that most `setindex!!` calls are very
fast, but some may incur substantial overhead due to resizing and copying data. It also
means that the largest index set so far determines the memory usage of the `PartialArray`.
`PartialArray`s are thus well-suited when most values in them will eventually be set. If
only a few scattered values are set, a structure like `SparseArray` may be more appropriate.
"""
struct PartialArray{ElType,num_dims}
    data::Array{ElType,num_dims}
    mask::Array{Bool,num_dims}

    function PartialArray(
        data::Array{ElType,num_dims}, mask::Array{Bool,num_dims}
    ) where {ElType,num_dims}
        if size(data) != size(mask)
            throw(ArgumentError("Data and mask arrays must have the same size"))
        end
        return new{ElType,num_dims}(data, mask)
    end
end

"""
    PartialArray{ElType,num_dims}(min_size=nothing)

Create a new empty `PartialArray` with the given element type and number of dimensions.

The optional argument `min_size` can be used to specify the minimum initial size. This is
purely a performance optimisation, to avoid resizing if the eventual size is known ahead of
time.
"""
function PartialArray{ElType,num_dims}(
    min_size::Union{Tuple,Nothing}=nothing
) where {ElType,num_dims}
    if min_size === nothing
        dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims)
    else
        dims = map(_partial_array_dim_size, min_size)
    end
    data = Array{ElType,num_dims}(undef, dims)
    mask = fill(false, dims)
    return PartialArray(data, mask)
end

Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims
Base.eltype(::PartialArray{ElType}) where {ElType} = ElType

# We deliberately don't define Base.size for PartialArray, because it is ill-defined.
# The size of the .data field is an implementation detail.
_internal_size(pa::PartialArray, args...) = size(pa.data, args...)

function Base.copy(pa::PartialArray)
    return PartialArray(copy(pa.data), copy(pa.mask))
end

function Base.:(==)(pa1::PartialArray, pa2::PartialArray)
    if ndims(pa1) != ndims(pa2)
        return false
    end
    size1 = _internal_size(pa1)
    size2 = _internal_size(pa2)
    # TODO(mhauru) This could be optimised by not calling checkbounds on all elements
    # outside the size of an array, but not sure it's worth it.
    merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1))
    for i in CartesianIndices(merge_size)
        m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false
        m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false
        if m1 != m2
            return false
        end
        if m1 && m2 && pa1.data[i] != pa2.data[i]
            return false
        end
    end
    return true
end

function Base.hash(pa::PartialArray, h::UInt)
    h = hash(ndims(pa), h)
    for i in eachindex(pa.mask)
        @inbounds if pa.mask[i]
            h = hash(i, h)
            h = hash(pa.data[i], h)
        end
    end
    return h
end

"""Return the length needed in a dimension given an index."""
_length_needed(i::Integer) = i
_length_needed(r::UnitRange) = last(r)

"""Take the minimum size that a dimension of a PartialArray needs to be, and
return the size we choose it to be.
This size will be the smallest possible power of @@ -84,92 +193,100 @@ function _partial_array_dim_size(min_dim) return factor^(Int(ceil(log(factor, min_dim)))) end -function _min_size(iarr::PartialArray, inds) - return ntuple(i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds)) +"""Return the minimum internal size needed for a `PartialArray` to be able set the value +at inds. +""" +function _min_size(pa::PartialArray, inds) + return ntuple(i -> max(_internal_size(pa, i), _length_needed(inds[i])), length(inds)) end -function _resize_partialarray(iarr::PartialArray, inds) - min_size = _min_size(iarr, inds) +"""Resize a PartialArray to be able to accommodate the index inds. This operates in place +for vectors, but makes a copy for higher-dimensional arrays, unless no resizing is +necessary, in which case this is a no-op.""" +function _resize_partialarray!!(pa::PartialArray, inds) + min_size = _min_size(pa, inds) new_size = map(_partial_array_dim_size, min_size) + if new_size == _internal_size(pa) + return pa + end # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_size) + new_data = Array{eltype(pa),ndims(pa)}(undef, new_size) new_mask = fill(false, new_size) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. - for i in CartesianIndices(iarr.data) - mask_val = iarr.mask[i] - @inbounds new_mask[i] = mask_val + @inbounds for i in CartesianIndices(pa.data) + mask_val = pa.mask[i] + new_mask[i] = mask_val if mask_val - @inbounds new_data[i] = iarr.data[i] + new_data[i] = pa.data[i] end end return PartialArray(new_data, new_mask) end # The below implements the same functionality as above, but more performantly for 1D arrays. -function _resize_partialarray(iarr::PartialArray{Eltype,1}, (ind,)) where {Eltype} +function _resize_partialarray!!(pa::PartialArray{Eltype,1}, (ind,)) where {Eltype} # Resize arrays to accommodate new indices. - old_size = _internal_size(iarr, 1) + old_size = _internal_size(pa, 1) min_size = max(old_size, _length_needed(ind)) new_size = _partial_array_dim_size(min_size) - resize!(iarr.data, new_size) - resize!(iarr.mask, new_size) - @inbounds iarr.mask[(old_size + 1):new_size] .= false - return iarr + if new_size == old_size + return pa + end + resize!(pa.data, new_size) + resize!(pa.mask, new_size) + @inbounds pa.mask[(old_size + 1):new_size] .= false + return pa end -function BangBang.setindex!!(pa::PartialArray, value, optic::IndexLens) - return BangBang.setindex!!(pa, value, optic.indices...) +_getindex(pa::PartialArray, optic::IndexLens) = _getindex(pa, optic.indices...) +_haskey(pa::PartialArray, optic::IndexLens) = _haskey(pa, optic.indices) +function _setindex!!(pa::PartialArray, value, optic::IndexLens) + return _setindex!!(pa, value, optic.indices...) end -Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indices...) 
-Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) -function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) - if length(inds) != ndims(iarr) - throw(BoundsError(iarr, inds)) +"""Throw an appropriate error if the given indices are invalid for `pa`.""" +function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + if length(inds) != ndims(pa) + throw(BoundsError(pa, inds)) end if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end - iarr = if checkbounds(Bool, iarr.mask, inds...) - iarr - else - _resize_partialarray(iarr, inds) - end - new_data = setindex!!(iarr.data, value, inds...) - if _is_multiindex(inds) - iarr.mask[inds...] .= true - else - iarr.mask[inds...] = true + throw(ArgumentError("Indexing PartialArrays with Colon is not supported")) end - return PartialArray(new_data, iarr.mask) + return nothing end -function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) - if length(inds) != ndims(iarr) - throw(ArgumentError("Invalid index $(inds)")) +function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if !_haskey(pa, inds) + throw(BoundsError(pa, inds)) end - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end - if !haskey(iarr, inds) - throw(BoundsError(iarr, inds)) - end - return getindex(iarr.data, inds...) + return getindex(pa.data, inds...) end -function Base.haskey(iarr::PartialArray, inds) - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) +function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + _check_index_validity(pa, inds) + return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) +end + +function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + pa = if checkbounds(Bool, pa.mask, inds...) + pa + else + _resize_partialarray!!(pa, inds) end - return checkbounds(Bool, iarr.mask, inds...) && - all(@inbounds(getindex(iarr.mask, inds...))) + new_data = setindex!!(pa.data, value, inds...) + if _is_multiindex(inds) + pa.mask[inds...] .= true + else + pa.mask[inds...] = true + end + return PartialArray(new_data, pa.mask) end Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) -Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) -_merge_recursive(_, x2) = x2 function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex) m1 = x1.mask[ind] @@ -193,7 +310,7 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) result = if merge_size == _internal_size(pa2) - # Either pa2 is strictly bigger than pa1, or they are equal in size. + # Either pa2 is strictly bigger than pa1 or they are equal in size. result = copy(pa2) for i in CartesianIndices(pa1.data) @inbounds if pa1.mask[i] @@ -240,6 +357,16 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end +struct VarNamedTuple{Names,Values} + data::NamedTuple{Names,Values} +end + +# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for +# PartialArrays. 
+Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data + +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) + function make_leaf(value, ::PropertyLens{S}) where {S} return VarNamedTuple(NamedTuple{(S,)}((value,))) end @@ -255,7 +382,7 @@ function make_leaf(value, optic::IndexLens{T}) where {T} # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray(et, num_inds) + iarr = PartialArray{et,num_inds}() return setindex!!(iarr, value, optic) end @@ -273,62 +400,35 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] function varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end -function Base.getindex(vnt::VarNamedTuple, name::VarName) - return getindex(vnt, varname_to_lens(name)) +function _getindex(vnt::VarNamedTuple, name::VarName) + return _getindex(vnt, varname_to_lens(name)) end -function Base.getindex(x::Union{VarNamedTuple,PartialArray}, optic::ComposedFunction) - subdata = getindex(x, optic.inner) - return getindex(subdata, optic.outer) -end -function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} - return getindex(vnt.data, S) -end - -function Base.haskey(vnt::VarNamedTuple, name::VarName) - return haskey(vnt, varname_to_lens(name)) +function _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return _getindex(vnt.data, S) end -Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true - -function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) - return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) +function _haskey(vnt::VarNamedTuple, name::VarName) + return _haskey(vnt, varname_to_lens(name)) end -Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(::VarNamedTuple, ::IndexLens) = false - -# TODO(mhauru) This is type piracy. -Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) +_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true -# TODO(mhauru) This is type piracy. -function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) - return BangBang.setindex!!(arr, value, optic.indices...) 
-end +_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(::VarNamedTuple, ::IndexLens) = false -function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) - return BangBang.setindex!!(vnt, value, varname_to_lens(name)) -end - -function BangBang.setindex!!( - vnt::Union{VarNamedTuple,PartialArray}, value, optic::ComposedFunction -) - sub = if haskey(vnt, optic.inner) - BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) - else - make_leaf(value, optic.outer) - end - return BangBang.setindex!!(vnt, sub, optic.inner) +function _setindex!!(vnt::VarNamedTuple, value, name::VarName) + return _setindex!!(vnt, value, varname_to_lens(name)) end -function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} +function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} # I would like this to just read - # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S)) + # return VarNamedTuple(_setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the # below? return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) @@ -338,9 +438,9 @@ function apply(func, vnt::VarNamedTuple, name::VarName) if !haskey(vnt.data, name.name) throw(KeyError(repr(name))) end - subdata = getindex(vnt, name) + subdata = _getindex(vnt, name) new_subdata = func(subdata) - return BangBang.setindex!!(vnt, new_subdata, name) + return _setindex!!(vnt, new_subdata, name) end function Base.map(func, vnt::VarNamedTuple) @@ -365,7 +465,7 @@ function Base.keys(vnt::VarNamedTuple) return result end -function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} +function _haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} if !haskey(vnt.data, S) return false end @@ -403,4 +503,35 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) return VarNamedTuple(result_data) end +# The following methods, indexing with ComposedFunction, are exactly the same for +# VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and +# inner lenses. +const VNT_OR_PA = Union{VarNamedTuple,PartialArray} + +function _getindex(x::VNT_OR_PA, optic::ComposedFunction) + subdata = _getindex(x, optic.inner) + return _getindex(subdata, optic.outer) +end + +function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction) + sub = if _haskey(vnt, optic.inner) + _setindex!!(_getindex(vnt, optic.inner), value, optic.outer) + else + make_leaf(value, optic.outer) + end + return _setindex!!(vnt, sub, optic.inner) +end + +function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) + return _haskey(vnt, optic.inner) && _haskey(_getindex(vnt, optic.inner), optic.outer) +end + +# The entry points for getting, setting, and checking, using the familiar functions. +Base.haskey(vnt::VarNamedTuple, key) = _haskey(vnt, key) +Base.getindex(vnt::VarNamedTuple, inds...) = _getindex(vnt, inds...) +BangBang.setindex!!(vnt::VarNamedTuple, value, inds...) = _setindex!!(vnt, value, inds...) +Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) +Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) +BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) 
+ end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 02ed3bca8..8cbf10a64 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,18 +1,32 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset -using DynamicPPL: @varname, VarNamedTuple +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using BangBang: setindex!! @testset "VarNamedTuple" begin + @testset "Construction" begin + pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() + pa1 = setindex!!(pa1, 1.0, 16) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,)) + pa2 = setindex!!(pa2, 1.0, 16) + @test pa1 == pa2 + end + @testset "Basic sets and gets" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 32.0 + @test haskey(vnt, @varname(a)) + @test !haskey(vnt, @varname(b)) vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + @test haskey(vnt, @varname(b)) + @test haskey(vnt, @varname(b[1])) + @test haskey(vnt, @varname(b[1:3])) + @test !haskey(vnt, @varname(b[4])) vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 64.0 @@ -42,6 +56,8 @@ using BangBang: setindex!! vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + @test haskey(vnt, @varname(e.f[3].g.h[2].i)) + @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 From 8f17dcf5fdf93a9d011a57b2b7b44cb2edd44981 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 15:25:05 +0000 Subject: [PATCH 020/148] Make PartialArray more type stable. --- src/varnamedtuple.jl | 36 +++++++++++++++++++++++++++++++++++- test/varnamedtuple.jl | 14 ++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 79a09f678..aed87017e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -179,6 +179,40 @@ function Base.hash(pa::PartialArray, h::UInt) return h end +""" + _concretise_eltype!!(pa::PartialArray) + +Concretise the element type of a `PartialArray`. + +Returns a new `PartialArray` with the same data and mask as `pa`, but with its element type +concretised to the most specific type that can hold all currently defined elements. + +Note that this function is fundamentally type unstable if the current element type of `pa` +is not already concrete. + +The name has a `!!` not because it mutates its argument, but because the return value +aliases memory with the argument, and is thus not independent of it. +""" +function _concretise_eltype!!(pa::PartialArray) + if isconcretetype(eltype(pa)) + return pa + end + new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) + # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)? + # In other words, does it help to be more concrete, even if we aren't fully concrete? + if new_et === eltype(pa) + # The types of the elements do not allow for concretisation. 
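+        # (For example, a mix of Float64 and String elements promotes only to Any.)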
+        return pa
+    end
+    new_data = Array{new_et,ndims(pa)}(undef, _internal_size(pa))
+    @inbounds for i in eachindex(pa.mask)
+        if pa.mask[i]
+            new_data[i] = pa.data[i]
+        end
+    end
+    return PartialArray(new_data, pa.mask)
+end
+
 """Return the length needed in a dimension given an index."""
 _length_needed(i::Integer) = i
 _length_needed(r::UnitRange) = last(r)
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index 8cbf10a64..08a65b018 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -101,6 +101,20 @@ using BangBang: setindex!!
         vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3])))
         @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0]
         @test !haskey(vnt, @varname(m[1]))
+
+        # The tests below are mostly significant for the type stability aspect. For the last
+        # test to pass, PartialArray needs to actively tighten its eltype when possible.
+        vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[1].a)))
+        @test @inferred(getindex(vnt, @varname(n[1].a))) == 1.0
+        vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[2].a)))
+        @test @inferred(getindex(vnt, @varname(n[2].a))) == 1.0
+        # This can't be type stable, because n[1] has inhomogeneous types.
+        vnt = setindex!!(vnt, 1.0, @varname(n[1].b))
+        @test getindex(vnt, @varname(n[1].b)) == 1.0
+        # The setindex!! call can't be type stable either, but it should return a
+        # VarNamedTuple with a concrete element type, and hence getindex can be inferred.
+        vnt = setindex!!(vnt, 1.0, @varname(n[2].b))
+        @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0
     end
 
     @testset "equality" begin

From 8547e250193496fcfedd0d9d950fe69dae65237c Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Tue, 25 Nov 2025 15:39:37 +0000
Subject: [PATCH 021/148] Implement `predict`, `returned`, `logjoint`, ... with
 `OnlyAccsVarInfo` (#1130)

* Use OnlyAccsVarInfo for many re-evaluation functions

* drop `fast_` prefix

* Add a changelog
---
 HISTORY.md                     |   5 +
 ext/DynamicPPLMCMCChainsExt.jl | 184 ++++++++++++++------------------
 src/chains.jl                  |  26 -----
 src/logdensityfunction.jl      |  16 +--
 src/model.jl                   |  15 +--
 5 files changed, 102 insertions(+), 144 deletions(-)

diff --git a/HISTORY.md b/HISTORY.md
index 91306c219..ff28349d8 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -32,9 +32,14 @@ You should not need to use these directly, please use `AbstractPPL.condition` an
 
 Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
 
+The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed.
+
 The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
 This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
 
+The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, has been changed such that if the chain does not contain all variables in the model, an error is thrown.
+Previously the behaviour would have been to sample missing variables.
+
 ## 0.38.9
 
 Remove warning when using Enzyme as the AD backend.
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index e74f0b8a9..8ad828648 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -1,41 +1,19 @@
 module DynamicPPLMCMCChainsExt
 
-using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
+using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
 using MCMCChains: MCMCChains
 
-_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
-
-function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
-    return _has_varname_to_symbol(chain.info)
-end
-
-function _check_varname_indexing(c::MCMCChains.Chains)
-    return DynamicPPL.supports_varname_indexing(c) ||
-        error("This `Chains` object does not support indexing using `VarName`s.")
-end
-
-function DynamicPPL.getindex_varname(
+function getindex_varname(
     c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
 )
-    _check_varname_indexing(c)
     return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
 end
-function DynamicPPL.varnames(c::MCMCChains.Chains)
-    _check_varname_indexing(c)
+function get_varnames(c::MCMCChains.Chains)
+    haskey(c.info, :varname_to_symbol) ||
+        error("This `Chains` object does not support indexing using `VarName`s.")
     return keys(c.info.varname_to_symbol)
 end
 
-function chain_sample_to_varname_dict(
-    c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
-) where {Tval}
-    _check_varname_indexing(c)
-    d = Dict{DynamicPPL.VarName,Tval}()
-    for vn in DynamicPPL.varnames(c)
-        d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
-    end
-    return d
-end
-
 """
     AbstractMCMC.from_samples(
         ::Type{MCMCChains.Chains},
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
     # Get parameters
     params_matrix = map(idxs) do (sample_idx, chain_idx)
         d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
-        for vn in DynamicPPL.varnames(chain)
-            d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
+        for vn in get_varnames(chain)
+            d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
         end
         d
     end
@@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples(
     return sort_chain ? sort(chain) : chain
 end
 
+"""
+    reevaluate_with_chain(
+        rng::AbstractRNG,
+        model::Model,
+        chain::MCMCChains.Chains,
+        accs::NTuple{N,AbstractAccumulator},
+        fallback=nothing,
+    )
+
+Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`,
+returning a matrix of `(retval, updated_varinfo)` tuples.
+
+This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
+initialisation strategy when re-evaluating the model. For many use cases the fallback
+should not be provided (as we expect the chain to contain all necessary variables), but for
+`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
+the posterior predictions).
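+
+For illustration, a sketch of a typical call (mirroring how `logjoint` below uses this
+helper; `model` and `chain` are placeholders):
+
+```julia
+accs = (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator())
+results = reevaluate_with_chain(model, chain, accs)
+logjoints = map(DynamicPPL.getlogjoint ∘ last, results)
+```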
+""" +function reevaluate_with_chain( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + return map(params_with_stats) do ps + DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback)) + end +end +function reevaluate_with_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -245,30 +263,18 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - - # Set up a VarInfo with the right accumulators - varinfo = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - ), + accs = ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), ) - _, varinfo = DynamicPPL.init!!(model, varinfo) - varinfo = DynamicPPL.typed_varinfo(varinfo) - - params_and_stats = AbstractMCMC.to_samples( - DynamicPPL.ParamsWithStats, parameter_only_chain + predictions = map( + DynamicPPL.ParamsWithStats ∘ last, + reevaluate_with_chain( + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() + ), ) - predictions = map(params_and_stats) do ps - _, varinfo = DynamicPPL.init!!( - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) - ) - DynamicPPL.ParamsWithStats(varinfo) - end chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -348,18 +354,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - first( - DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), - ), - ) - end + return map(first, reevaluate_with_chain(model, chain, (), nothing)) end """ @@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities( ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) where {whichlogprob,Tout} - vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) - vi = DynamicPPL.setaccs!!(vi, (acc,)) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - pointwise_logps = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Re-evaluate the model - _, vi = DynamicPPL.init!!( - model, - vi, - DynamicPPL.InitFromParams(values_dict, 
DynamicPPL.InitFromPrior()), - ) - DynamicPPL.getacc(vi, Val(accname)).logps - end - + pointwise_logps = + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) + DynamicPPL.getacc(vi, Val(accname)).logps + end # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps @@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logjoint(model, argvals_dict) - end + return map( + DynamicPPL.getlogjoint ∘ last, + reevaluate_with_chain( + model, + chain, + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), + nothing, + ), + ) end """ @@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.loglikelihood(model, argvals_dict) - end + return map( + DynamicPPL.getloglikelihood ∘ last, + reevaluate_with_chain( + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing + ), + ) end """ @@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logprior(model, argvals_dict) - end + return map( + DynamicPPL.getlogprior ∘ last, + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), + ) end end diff --git a/src/chains.jl b/src/chains.jl index f176b8e68..2fcd4e713 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -1,29 +1,3 @@ -""" - supports_varname_indexing(chain::AbstractChains) - -Return `true` if `chain` supports indexing using `VarName` in place of the -variable name index. -""" -supports_varname_indexing(::AbstractChains) = false - -""" - getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) - -Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function getindex_varname end - -""" - varnames(chains::AbstractChains) - -Return an iterator over the varnames present in `chains`. 
- -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function varnames end - """ ParamsWithStats diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 65eab448e..bcdd0bb25 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -193,21 +193,21 @@ end # LogDensityProblems.jl interface # ################################### """ - fast_ldf_accs(getlogdensity::Function) + ldf_accs(getlogdensity::Function) Determine which accumulators are needed for fast evaluation with the given `getlogdensity` function. """ -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) +ldf_accs(::Function) = default_accumulators() +ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function ldf_accs(::typeof(getlogjoint)) return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -function fast_ldf_accs(::typeof(getlogprior_internal)) +function ldf_accs(::typeof(getlogprior_internal)) return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) +ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} model::M @@ -219,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) strategy = InitFromParams( VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - accs = fast_ldf_accs(f.getlogdensity) + accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) return f.getlogdensity(vi) end diff --git a/src/model.jl b/src/model.jl index 9029318b1..7d5bbf2fb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1181,12 +1181,15 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) # Note: we can't use `fix(model, parameters)` because # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 - # Use `nothing` as the fallback to ensure that any missing parameters cause an error - ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) - new_model = setleafcontext(model, ctx) - # We can't use new_model() because that overwrites it with an InitContext of its own. 
- return first(evaluate!!(new_model, vi)) + return first( + init!!( + model, + DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()), + # Use `nothing` as the fallback to ensure that any missing parameters cause an + # error + InitFromParams(parameters, nothing), + ), + ) end From 04b3383dafcfa9beb14cee411e42a9d2794043c3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 16:17:06 +0000 Subject: [PATCH 022/148] Fixes and improvements to VNT --- src/varnamedtuple.jl | 199 ++++++++++++++++++++++++------------------ test/varnamedtuple.jl | 81 +++++++++++++++++ 2 files changed, 194 insertions(+), 86 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index aed87017e..c8c7883dd 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -391,37 +391,57 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end -struct VarNamedTuple{Names,Values} - data::NamedTuple{Names,Values} +function Base.keys(pa::PartialArray) + inds = findall(pa.mask) + lenses = map(x -> IndexLens(Tuple(x)), inds) + ks = Any[] + for l in lenses + val = getindex(pa.data, l.indices...) + if val isa VarNamedTuple + subkeys = keys(val) + for vn in subkeys + lens = varname_to_lens(vn) + push!(ks, _compose_no_identity(lens, l)) + end + else + push!(ks, l) + end + end + return ks end -# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for -# PartialArrays. -Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +""" + VarNamedTuple{Names,Values} -Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) +A `NamedTuple`-like structure with `VarName` keys. -function make_leaf(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,))) -end -make_leaf(value, ::typeof(identity)) = value -function make_leaf(value, optic::ComposedFunction) - sub = make_leaf(value, optic.outer) - return make_leaf(sub, optic.inner) -end +`VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an +efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and +`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods +are `merge`, which recursively merges two `VarNamedTuple`s. -function make_leaf(value, optic::IndexLens{T}) where {T} - inds = optic.indices - num_inds = length(inds) - # Check if any of the indices are ranges or colons. If yes, value needs to be an - # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray{et,num_inds}() - return setindex!!(iarr, value, optic) +The one major limitation is that indexing by `VarName`s with `Colon`s, (e.g. `a[:]`) is not +supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, +say `a[1]` and `a[3]`, are defined. + +`setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store +heterogeneous data under different indices of the same symbol. That is, if one either + +* sets `a[1]` and `a[2]` to be of different types, or +* sets `a[1].b` and `a[2].c`, without setting `a[1].c`. or `a[2].b`, + +then getting values for `a[1]` or `a[2]` will not be type stable. 
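+
+For example (a minimal sketch; `setindex!!` here is `BangBang.setindex!!`):
+
+```julia
+vnt = setindex!!(VarNamedTuple(), 1.0, @varname(a[1]))  # eltype Float64, inferrable
+vnt = setindex!!(vnt, "hi", @varname(a[2]))             # eltype widens to Any
+vnt[@varname(a[1])]  # still works, but the return type can no longer be inferred
+```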
+""" +struct VarNamedTuple{Names,Values} + data::NamedTuple{Names,Values} end VarNamedTuple() = VarNamedTuple((;)) +Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) + +# TODO(mhauru) Rework this printing. function Base.show(io::IO, vnt::VarNamedTuple) print(io, "(") for (i, (name, value)) in enumerate(pairs(vnt.data)) @@ -434,26 +454,22 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +""" + varname_to_lens(name::VarName{S}) where {S} +Convert a `VarName` to an `Accessor` lens, wrapping the first symdol in a `PropertyLens`. +""" function varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end -function _getindex(vnt::VarNamedTuple, name::VarName) - return _getindex(vnt, varname_to_lens(name)) -end -function _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} - return _getindex(vnt.data, S) -end - -function _haskey(vnt::VarNamedTuple, name::VarName) - return _haskey(vnt, varname_to_lens(name)) -end - -_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true +_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name)) +_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _getindex(vnt.data, S) +_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name)) _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false function _setindex!!(vnt::VarNamedTuple, value, name::VarName) @@ -468,8 +484,41 @@ function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) end -function apply(func, vnt::VarNamedTuple, name::VarName) - if !haskey(vnt.data, name.name) +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) + +# TODO(mhauru) Check the performance of this, and make it into a generated function if +# necessary. +function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + result_data = vnt1.data + for k in keys(vnt2.data) + val = if haskey(result_data, k) + _merge_recursive(result_data[k], vnt2.data[k]) + else + vnt2.data[k] + end + Accessors.@reset result_data[k] = val + end + return VarNamedTuple(result_data) +end + +""" + apply!!(func, vnt::VarNamedTuple, name::VarName) + +Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. 
+ +```jldoctest +julia> vnt = VarNamedTuple() +() + +julia> vnt = setindex!!(vnt, [1,2,3], @varname(a)) +(a -> [1, 2, 3]) + +julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a)) +(a -> [2, 3, 4]) +``` +""" +function apply!!(func, vnt::VarNamedTuple, name::VarName) + if !haskey(vnt, name) throw(KeyError(repr(name))) end subdata = _getindex(vnt, name) @@ -477,11 +526,6 @@ function apply(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end -function Base.map(func, vnt::VarNamedTuple) - new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) - return VarNamedTuple(new_data) -end - function Base.keys(vnt::VarNamedTuple) result = () for sym in keys(vnt.data) @@ -489,54 +533,18 @@ function Base.keys(vnt::VarNamedTuple) if subdata isa VarNamedTuple subkeys = keys(subdata) result = ( - (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result... + result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... ) + elseif subdata isa PartialArray + subkeys = keys(subdata) + result = (result..., (VarName{sym}(lens) for lens in subkeys)...) else - result = (VarName{sym}(), result...) + result = (result..., VarName{sym}()) end - subkeys = keys(vnt.data[sym]) end return result end -function _haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} - if !haskey(vnt.data, S) - return false - end - subdata = vnt.data[S] - return if Optic === typeof(identity) - true - elseif Optic <: IndexLens - try - AbstractPPL.getoptic(name)(subdata) - true - catch e - if e isa BoundsError || e isa KeyError - false - else - rethrow(e) - end - end - else - haskey(subdata, AbstractPPL.unprefix(name, VarName{S}())) - end -end - -# TODO(mhauru) Check the performance of this, and make it into a generated function if -# necessary. -function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - result_data = vnt1.data - for k in keys(vnt2.data) - val = if haskey(result_data, k) - _merge_recursive(result_data[k], vnt2.data[k]) - else - vnt2.data[k] - end - Accessors.@reset result_data[k] = val - end - return VarNamedTuple(result_data) -end - # The following methods, indexing with ComposedFunction, are exactly the same for # VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and # inner lenses. @@ -561,11 +569,30 @@ function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) end # The entry points for getting, setting, and checking, using the familiar functions. -Base.haskey(vnt::VarNamedTuple, key) = _haskey(vnt, key) -Base.getindex(vnt::VarNamedTuple, inds...) = _getindex(vnt, inds...) -BangBang.setindex!!(vnt::VarNamedTuple, value, inds...) = _setindex!!(vnt, value, inds...) +Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) +Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) +BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) 
+function make_leaf(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,))) +end +make_leaf(value, ::typeof(identity)) = value +function make_leaf(value, optic::ComposedFunction) + sub = make_leaf(value, optic.outer) + return make_leaf(sub, optic.inner) +end + +function make_leaf(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) + iarr = PartialArray{et,num_inds}() + return setindex!!(iarr, value, optic) +end + end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 08a65b018..e3e98d270 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -236,6 +236,87 @@ using BangBang: setindex!! expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) @test merge(vnt2, vnt1) == expected_merge_21 end + + @testset "keys" begin + vnt = VarNamedTuple() + @test keys(vnt) == () + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(a)) + @test keys(vnt) == (@varname(a),) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test keys(vnt) == (@varname(a), @varname(b)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test keys(vnt) == (@varname(a), @varname(b)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(j[6])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + end end end From 59c4dcbba214d484faa4cbf76e206b76c38496da Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 17:47:38 +0000 Subject: [PATCH 023/148] Proper printing and constructors --- src/varnamedtuple.jl | 103 +++++++++++++++++++++++++++--------------- test/varnamedtuple.jl | 75 +++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 37 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index c8c7883dd..7880275a5 100644 --- 
a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,23 +8,23 @@ using ..DynamicPPL: _compose_no_identity export VarNamedTuple -# We define our own getindex, setindex!!, and haskey functions to be able to override their -# behaviour for some types exported from elsewhere without type piracy. This is needed -# because -# 1. We want to index into things with lenses (from Accessors.jl) using getindex and -# setindex!!. -# 2. We want to use getindex, setindex!!, and haskey as the universal functions for getting, -# setting, checking. This includes e.g. checking whether an index is valid for an Array, -# which would normally be done with checkbounds. -_haskey(x, key) = Base.haskey(x, key) -_getindex(x, inds...) = Base.getindex(x, inds...) -_setindex!!(x, value, inds...) = BangBang.setindex!!(x, value, inds...) -_getindex(arr::AbstractArray, optic::IndexLens) = _getindex(arr, optic.indices...) +# We define our own getindex, setindex!!, and haskey functions, which we use to +# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be +# able to override their behaviour for some types exported from elsewhere without type +# piracy. This is needed because +# 1. We would want to index into things with lenses (from Accessors.jl) using getindex and +# setindex!!, but Accessors does not define these methods. +# 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays. +function _getindex end +function _haskey end +function _setindex!! end + +_getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) _haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) +_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) function _setindex!!(arr::AbstractArray, value, optic::IndexLens) - return _setindex!!(arr, value, optic.indices...) + return setindex!!(arr, value, optic.indices...) end -_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) # Some utilities for checking what sort of indices we are dealing with. _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) @@ -122,22 +122,44 @@ purely a performance optimisation, to avoid resizing if the eventual size is kno time. """ function PartialArray{ElType,num_dims}( - min_size::Union{Tuple,Nothing}=nothing + args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing ) where {ElType,num_dims} - if min_size === nothing - dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + dims = if min_size === nothing + ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) else - dims = map(_partial_array_dim_size, min_size) + map(_partial_array_dim_size, min_size) end - dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{ElType,num_dims}(undef, dims) mask = fill(false, dims) - return PartialArray(data, mask) + pa = PartialArray(data, mask) + + for (inds, value) in args + pa = _setindex!!(pa, convert(ElType, value), inds...) 
+    end
+    return pa
 end
 
 Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims
 Base.eltype(::PartialArray{ElType}) where {ElType} = ElType
 
+function Base.show(io::IO, pa::PartialArray)
+    print(io, "PartialArray{", eltype(pa), ",", ndims(pa), "}(")
+    is_first = true
+    for inds in CartesianIndices(pa.mask)
+        if @inbounds(!pa.mask[inds])
+            continue
+        end
+        if !is_first
+            print(io, ", ")
+            is_first = false
+        end
+        val = @inbounds(pa.data[inds])
+        print(io, Tuple(inds), " => ", val)
+    end
+    print(")")
+    return nothing
+end
+
 # We deliberately don't define Base.size for PartialArray, because it is ill-defined.
 # The size of the .data field is an implementation detail.
 _internal_size(pa::PartialArray, args...) = size(pa.data, args...)
@@ -420,9 +442,10 @@ efficient and type stable manner. It is mainly used through `getindex`, `setinde
 `haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods
 are `merge`, which recursively merges two `VarNamedTuple`s.
 
-The one major limitation is that indexing by `VarName`s with `Colon`s, (e.g. `a[:]`) is not
-supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`,
-say `a[1]` and `a[3]`, are defined.
+There are two major limitations to indexing `VarNamedTuple`s:
+
+* `VarName`s with `Colon`s (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined.
+* Any `VarName`s with `IndexLens`es must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`.
 
 `setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store
 heterogeneous data under different indices of the same symbol. That is, if one either
@@ -436,10 +459,10 @@ then getting values for `a[1]` or `a[2]` will not be type stable.
 struct VarNamedTuple{Names,Values}
     data::NamedTuple{Names,Values}
 end
 
-VarNamedTuple() = VarNamedTuple((;))
+VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...))
 
 Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data
 Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h)
 
-# TODO(mhauru) Rework this printing.
function Base.show(io::IO, vnt::VarNamedTuple) - print(io, "(") + print(io, "VarNamedTuple(;") for (i, (name, value)) in enumerate(pairs(vnt.data)) if i > 1 - print(io, ", ") + print(io, ",") end - print(io, name, " -> ") - print(io, value) + print(io, " ", name, "=", value) end return print(io, ")") end @@ -464,11 +485,11 @@ function varname_to_lens(name::VarName{S}) where {S} end _getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name)) -_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _getindex(vnt.data, S) +_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = getindex(vnt.data, S) _getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] _haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name)) -_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false @@ -572,14 +593,24 @@ end Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) + Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) -function make_leaf(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,))) -end +""" + make_leaf(value, optic) + +Make a new leaf node for a VarNamedTuple. + +This is the function that sets any `optic` that is a `PropertyLens` to be stored as a +`VarNamedTuple`, any `IndexLens` to be stored as a `PartialArray`, and other `identity` +optics to be stored as raw values. It is the link that joins `VarNamedTuple` and +`PartialArray` together. +""" make_leaf(value, ::typeof(identity)) = value +make_leaf(value, ::PropertyLens{S}) where {S} = VarNamedTuple(NamedTuple{(S,)}((value,))) + function make_leaf(value, optic::ComposedFunction) sub = make_leaf(value, optic.outer) return make_leaf(sub, optic.inner) @@ -591,8 +622,8 @@ function make_leaf(value, optic::IndexLens{T}) where {T} # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray{et,num_inds}() - return setindex!!(iarr, value, optic) + pa = PartialArray{et,num_inds}() + return _setindex!!(pa, value, optic) end end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index e3e98d270..77edefa9a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -6,11 +6,43 @@ using BangBang: setindex!! 
@testset "VarNamedTuple" begin @testset "Construction" begin + vnt1 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) + vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + vnt2 = VarNamedTuple(; + a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) + ) + @test vnt1 == vnt2 + pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,)) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(; min_size=(16,)) pa2 = setindex!!(pa2, 1.0, 16) + pa3 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(16 => 1.0) + pa4 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,) => 1.0) + @test pa1 == pa2 + @test pa1 == pa3 + @test pa1 == pa4 + + pa1 = DynamicPPL.VarNamedTuples.PartialArray{String,3}() + pa1 = setindex!!(pa1, "a", 2, 3, 4) + pa1 = setindex!!(pa1, "b", 1, 2, 4) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = setindex!!(pa2, "a", 2, 3, 4) + pa2 = setindex!!(pa2, "b", 1, 2, 4) + pa3 = DynamicPPL.VarNamedTuples.PartialArray{String,3}( + (2, 3, 4) => "a", (1, 2, 4) => "b" + ) @test pa1 == pa2 + @test pa1 == pa3 + + @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1,) => "a") + @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}( + (1,) => 1; min_size=(2, 2) + ) end @testset "Basic sets and gets" begin @@ -317,6 +349,47 @@ using BangBang: setindex!! ) @test all(x -> haskey(vnt, x), keys(vnt)) end + + @testset "printing" begin + vnt = VarNamedTuple() + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(;)" + + vnt = setindex!!(vnt, 1.0, @varname(a)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(; a=1.0)" + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(; a=1.0, b=[1, 2, 3])" + + vnt = setindex!!(vnt, 15, @varname(c[2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15)""" + + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(; a=1.0, b=[1, 2, 3], \ + c=PartialArray{Int64,1}((2,) => 15, \ + d=VarNamedTuple(; \ + e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ + Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ + VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => \ + 16.0(2,) => 17.0))))""" + end end end From 381b1dd4b1bb4646e61c962295908cbf015f9ff5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 19:05:57 +0000 Subject: [PATCH 024/148] Fix PartialArray printing --- src/varnamedtuple.jl | 3 ++- test/varnamedtuple.jl | 14 +++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 7880275a5..e47b27e9e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -151,12 +151,13 @@ function Base.show(io::IO, pa::PartialArray) end if !is_first print(io, ", ") + else is_first = false end val = 
@inbounds(pa.data[inds]) print(io, Tuple(inds), " => ", val) end - print(")") + print(io, ")") return nothing end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 77edefa9a..7b26e9be7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -374,7 +374,7 @@ using BangBang: setindex!! show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15)""" + VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() @@ -382,13 +382,13 @@ using BangBang: setindex!! output = String(take!(io)) @test output == """ VarNamedTuple(; a=1.0, b=[1, 2, 3], \ - c=PartialArray{Int64,1}((2,) => 15, \ + c=PartialArray{Int64,1}((2,) => 15), \ d=VarNamedTuple(; \ - e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ - Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ - Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ - VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => \ - 16.0(2,) => 17.0))))""" + e=PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ + VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0))))))""" end end From 88db66dd8496d772bb37333aca3ec6096c5e6e83 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 19:06:33 +0000 Subject: [PATCH 025/148] Update the design doc --- docs/src/internals/varnamedtuple.md | 173 ++++++++++++++++------------ 1 file changed, 99 insertions(+), 74 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 9f7a84cdb..0194d05d7 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -1,112 +1,137 @@ -# VarNamedTuple as the basis of VarInfo +# VarNamedTuple -This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that. +In DynamicPPL there is often a need to store data keyed by `VarName`s. +This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. +Historically we've had several different approaches to this: Dictionaries, NamedTuples, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. -## The current situation +To unify the treatment of these use cases, and handle them all in a robust and performant way, is the purpose of `VarNamedTuple`, aka VNT. +It's a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type stable and performant manner. -We currently have the following AbstractVarInfo types: +`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`. +Let's first talk about the `NamedTuple` part. +This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`. +In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lens as the keys. 
+For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as - - A: VarInfo with Metadata - - B: VarInfo with VarNamedVector - - C: VarInfo with NamedTuple, with values being Metadata - - D: VarInfo with NamedTuple, with values being VarNamedVector - - E: SimpleVarInfo with NamedTuples - - F: SimpleVarInfo with OrderedDict - -A and C are the classic ones, and the defaults. C wraps groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types and for the VarInfo to still be type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks are quite a lot slower, which needs work. +``` +VarNamedTuple(; x=1, y=VarNamedTuple(; z=2)) +``` -E and F are entirely distinct in implementation from the others. E is simply a mapping from Symbols to values, with each VarName being converted to a single symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type. +where `VarNamedTuple(; x=a, y=b)` is just a thin wrapper around the `NamedTuple` `(; x=a, y=b)`. -A-D carry within them values for variables, but also their bijectors/distributions, and store all values vectorised, using the bijectors to map to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit. +It's often handy to think of this as a tree, with each node being a `VarNamedTuple`, like so: -TODO: Write a better summary of pros and cons of each approach. +``` + VNT +x / \ y + 1 VNT + \ z + 2 +``` -## VarNamedTuple +If all `VarName`s consisted of only `PropertyLens`es we would be done designing the data structure. +However, recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). +The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`. +It is the `IndexLenses` that require more structure. -VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion. +An `IndexLens` is the indexing layer in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. +`VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. +Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. -An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well. 
+When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`. +When we meet an `IndexLens`, we instead instert into the tree something called a `PartialArray`. -Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code. +A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset. +It is a data structure we define ourselves for use within `VarNamedTuple`s. +A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and this thus not an `AbstractArray`. +This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions. +The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are. -A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could hold with them extra information. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked. +This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`. +A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be. +If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question. -I sometimes shorten VarNamedTuple to VNT. +`PartialArray`s have other restrictions, compared to the full indexing syntax of Julia, as well: +They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays as in `rand(4)[[true, false, true, false]]`). +This is mostly because we haven't seen a need to support them, and implementing would complicate the codebase for little gain. +We may add support for them later if needed. -Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as +`PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s. +Thus we nest them with `VarNamedTuple`s to support storing `VarName`s with arbitrary combinations of `PropertyLens`es and `IndexLens`es. +A code example illustrates this the best: -``` -(; x=1, y=(; z=2)) -``` +```julia +julia> vnt = VarNamedTuple(); -(This is a slight simplification, really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.) -This forms a tree, with each node being a NamedTuple, like so: +julia> vnt = setindex!!(vnt, 1.0, @varname(a)); -``` - NT -x / \ y - 1 NT - \ z - 2 -``` +julia> vnt = setindex!!(vnt, [2.0, 3.0], @varname(b.c)); -Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has the keys `x` and `y`. This is like with the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). 
Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree. +julia> vnt = setindex!!(vnt, [:hip, :hop], @varname(d.e[2].f[3:4])); -For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`. +julia> print(vnt) +VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) +``` -This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: They are either the raw value itself if the last lens of the VarName is an `identity`, or otherwise they are something that can be indexed with an `IndexLens`, such as an `Array`. +The output there may be a bit hard bit hard to parse, so to illustrate: -To get a value from a VarNamedTuple is very simple: For `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it. +```julia +julia> vnt[@varname(b)] +VarNamedTuple(; c=[2.0, 3.0]) -Setting values in a VNT is equally simple if there are no `IndexLenses`: For `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`. +julia> vnt[@varname(b.c[1])] +2.0 -The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`. +julia> vnt[@varname(d.e)] +PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))) - 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: Just set the third element. - 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector? - 3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error? +julia> vnt[@varname(d.e[2].f)] +PartialArray{Symbol,1}((3,) => hip, (4,) => hop) +``` -A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: Any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary. +The above example also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. 
+We do not define a method for `Base.setindex!` at all, the `setindex!!` is the only way. +This is because `VarNamedTuple` mixes mutable an immutable data structures. +It is also for user convenience: +One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. +Rather the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed. +`VarNamedTuple`, or more precisely `PartialArray`, even explicitly concretises element types whenever possible. +For instance, one can make an abstractly typed `VarNamedTuple` like so: -To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this: +```julia +julia> vnt = VarNamedTuple(); -Unlike I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in the leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`. +julia> vnt = setindex!!(vnt, 1.0, @varname(a[1])); -Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like +julia> vnt = setindex!!(vnt, "hello", @varname(a[2])); -``` - --NT-- - x / \ y -f(1, identity) NT - \ z - f(2, identity) +julia> print(vnt) +VarNamedTuple(; a=PartialArray{Any,1}((1,) => 1.0, (2,) => hello)) ``` -The above, first draft of VNT which did not include `make_leaf` is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal IndexLenses")`. The problems 2. and 3. above are "solved" by making it `make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type, that can grow as needed when indexed into. +Note the element type of `PartialArray{Any}`. +But if one changes the values to make them homogeneous, the element type is automatically made concrete again: -The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and ,if necessary, leave the option for power users to customise behaviour. The first ones to implement would be +```julia +julia> vnt = setindex!!(vnt, "me here", @varname(a[1])); - - `make_leaf` that returns a Metadata object. This would be a direct replacement for type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple. - - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict. - -You may ask, have we simple gone from too many VarInfo types to too many `make_leaf` functions. Yes we have. But hopefully we have gained something in the process: - - - The leaf types can be simpler. They do not need to deal with VarNames any more, they only need to deal with `identity` lenses and `IndexLenses`. - - All AbstactVarInfos are as type stable as their leaf types allow. There is no more notion of an untyped VarInfo being converted to a typed one. 
- Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels.
-  - Many functions that are currently implemented individually for each AbstactVarInfo type would now have a single implementation for the VarNamedTuple-based AbstactVarInfo type, reducing code duplication. I would also hope to get ride of most of the generated functions for in `varinfo.jl`.
+Note the element type of `PartialArray{Any}`.
+But if one changes the values to make them homogeneous, the element type is automatically made concrete again:
 
-My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know.
+```julia
+julia> vnt = setindex!!(vnt, "me here", @varname(a[1]));
+
+julia> print(vnt)
+VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello))
+```
 
-I think the two big questions are:
+This approach is at the core of why `VarNamedTuple` is performant:
+As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type stable.
 
- - Will we run into some big, unanticipated blockers when we start to implement this.
- - Will the nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up.
+Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for `getindex` and `setindex!!`.
+However, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`; you can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability.
 
-I'll try to derisk these early on in this PR.
+The remainder of this document collects some miscellaneous notes.
 
-## Questions / issues
+## Limitations
 
- - People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
- - When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
- - Do `Colon` indices cause any extra trouble for the leafnodes?
+This design has several benefits, for performance and generality, but it also has limitations:
+
+ 1. The lack of support for `Colon`s in `VarName`s.
+ 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing.
+ 3. An asymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`.
+    The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`.
+    The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`.
+    A short sketch of this asymmetry is given below.
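+
+For instance (an illustrative sketch of the third limitation):
+
+```julia
+vnt1 = setindex!!(VarNamedTuple(), [1.0, 2.0], @varname(a))
+vnt1[@varname(a)]     # the whole array, [1.0, 2.0]
+vnt1[@varname(a[2])]  # 2.0
+
+vnt2 = setindex!!(VarNamedTuple(), 1.0, @varname(b[1]))
+vnt2 = setindex!!(vnt2, 2.0, @varname(b[2]))
+vnt2[@varname(b[2])]  # 2.0
+vnt2[@varname(b)]     # a PartialArray, not a Base.Array
+```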
From a6d56a2b9074d9da27eea4a6e4a2ab9a3013913f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 27 Nov 2025 12:08:24 +0000 Subject: [PATCH 026/148] Improve FastLDF type stability when all parameters are linked or unlinked (#1141) * Improve type stability when all parameters are linked or unlinked * fix a merge conflict * fix enzyme gc crash (locally at least) * Fixes from review --- src/chains.jl | 8 +++-- src/contexts/init.jl | 33 ++++++++++++++++--- src/logdensityfunction.jl | 56 +++++++++++++++++++++++++-------- test/integration/enzyme/main.jl | 10 ++++-- test/logdensityfunction.jl | 16 ++++++++++ 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 2fcd4e713..d01606c3d 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -130,13 +130,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.LogDensityFunction, + ldf::DynamicPPL.LogDensityFunction{Tlink}, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, -) +) where {Tlink} strategy = InitFromParams( - VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), nothing, ) accs = if include_log_probs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..80a494c23 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -223,6 +223,12 @@ end A struct that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. +The type parameter `Tlink` can be either `true` or `false`, to mark that the variables in +this `VectorWithRanges` are linked/not linked, or `nothing` if either the linking status is +not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not +affect functionality or correctness, but causes more work to be done at runtime, with +possible impacts on type stability and performance. + In the simplest case, this could be accomplished only with a single dictionary mapping VarNames to ranges and link status. However, for performance reasons, we separate out VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All @@ -231,13 +237,26 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
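+
+For example, a model with an identity-optic variable `x` and a non-identity variable
+`y[1]`, both unlinked, might be wrapped like this (a sketch only; the
+`RangeAndLinked(range, is_linked)` field order is an assumption made for illustration):
+
+```julia
+vwr = VectorWithRanges{false}(
+    (; x=RangeAndLinked(1:1, false)),  # identity-optic VarNames go in the NamedTuple
+    Dict{VarName,RangeAndLinked}(@varname(y[1]) => RangeAndLinked(2:2, false)),
+    [0.5, -1.2],  # the full parameter vector
+)
+```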
""" -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + if !(Tlink isa Union{Bool,Nothing}) + throw( + ArgumentError( + "VectorWithRanges type parameter has to be one of `true`, `false`, or `nothing`.", + ), + ) + end + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,11 +271,15 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, -) + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} vr = p.params range_and_linked = _get_range_and_linked(vr, vn) - transform = if range_and_linked.is_linked + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked from_linked_vec_transform(dist) else from_vec_transform(dist) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index bcdd0bb25..7d1094fa3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -163,6 +166,21 @@ struct LogDensityFunction{ # Figure out which variable corresponds to which index, and # which variables are linked. 
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Figure out if all variables are linked, unlinked, or mixed + link_statuses = Bool[] + for ral in all_iden_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in all_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -172,12 +190,13 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -209,15 +228,24 @@ end ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) @@ -225,9 +253,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params @@ -235,10 +263,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..edfd67d18 100644 --- 
a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,11 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => +ADTYPES = ( + ( + "EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), - "EnzymeReverse" => + ), + ( + "EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), + ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 06492d6e1..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -108,6 +108,22 @@ end end end +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when From eca65d5cd6f6147199f6fddb6f9be009abcbf454 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:14:27 +0000 Subject: [PATCH 027/148] Fix a test --- test/varnamedtuple.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7b26e9be7..a93db10cc 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -384,8 +384,8 @@ using BangBang: setindex!! VarNamedTuple(; a=1.0, b=[1, 2, 3], \ c=PartialArray{Int64,1}((2,) => 15), \ d=VarNamedTuple(; \ - e=PartialArray{VarNamedTuple{(:f,), \ - Tuple{VarNamedTuple{(:g,), \ + e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ + Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0))))))""" From 5e27a052725226c6cf15b0e44a38d63cd39eb032 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:45:12 +0000 Subject: [PATCH 028/148] Fix copy and show --- src/varnamedtuple.jl | 39 ++++++++++++++++++++++++++++++++++++--- test/varnamedtuple.jl | 18 +++++++++--------- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index e47b27e9e..2068566b4 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -155,7 +155,12 @@ function Base.show(io::IO, pa::PartialArray) is_first = false end val = @inbounds(pa.data[inds]) - print(io, Tuple(inds), " => ", val) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + show(io, Tuple(inds)) + print(io, " => ") + show(io, val) end print(io, ")") return nothing @@ -166,7 +171,17 @@ end _internal_size(pa::PartialArray, args...) = size(pa.data, args...) function Base.copy(pa::PartialArray) - return PartialArray(copy(pa.data), copy(pa.mask)) + # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively + # copy. 
+ pa_copy = PartialArray(copy(pa.data), copy(pa.mask)) + if VarNamedTuple <: eltype(pa) || eltype(pa) <: VarNamedTuple + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] && pa_copy.data[i] isa VarNamedTuple + pa_copy.data[i] = copy(pa.data[i]) + end + end + end + return pa_copy end function Base.:(==)(pa1::PartialArray, pa2::PartialArray) @@ -471,11 +486,29 @@ function Base.show(io::IO, vnt::VarNamedTuple) if i > 1 print(io, ",") end - print(io, " ", name, "=", value) + print(io, " ") + print(io, name) + print(io, "=") + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the value itself is `show`n. The latter ensures that + # strings are quoted, Symbols are prefixed with :, etc. + show(io, value) end return print(io, ")") end +function Base.copy(vnt::VarNamedTuple{Names}) where {Names} + # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements, + # which we recursively copy. + return VarNamedTuple( + NamedTuple{Names}( + map( + x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data) + ), + ), + ) +end + """ varname_to_lens(name::VarName{S}) where {S} diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index a93db10cc..ad5fba8c1 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -357,35 +357,35 @@ using BangBang: setindex!! output = String(take!(io)) @test output == "VarNamedTuple(;)" - vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, "s", @varname(a)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == "VarNamedTuple(; a=1.0)" + @test output == """VarNamedTuple(; a="s")""" vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == "VarNamedTuple(; a=1.0, b=[1, 2, 3])" + @test output == """VarNamedTuple(; a="s", b=[1, 2, 3])""" - vnt = setindex!!(vnt, 15, @varname(c[2])) + vnt = setindex!!(vnt, :dada, @varname(c[2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15))""" + VarNamedTuple(; a="s", b=[1, 2, 3], c=PartialArray{Symbol,1}((2,) => :dada))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], \ - c=PartialArray{Int64,1}((2,) => 15), \ + VarNamedTuple(; a="s", b=[1, 2, 3], \ + c=PartialArray{Symbol,1}((2,) => :dada), \ d=VarNamedTuple(; \ - e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ - Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ + e=PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0))))))""" From 050b8c54ca435a50a6cf24bb4052b5a65385b0f9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:49:41 +0000 Subject: [PATCH 029/148] Add test_invariants to VNT tests --- test/varnamedtuple.jl | 67 +++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index ad5fba8c1..f55f8b996 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,47 +2,67 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple +using 
DynamicPPL.VarNamedTuples: PartialArray using BangBang: setindex!! +""" + test_invariants(vnt::VarNamedTuple) + +Test properties that should hold for all VarNamedTuples. + +Uses @test for all the tests. Intended to be called inside a @testset. +""" +function test_invariants(vnt::VarNamedTuple) + # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. + for k in keys(vnt) + @test haskey(vnt, k) + v = getindex(vnt, k) + vnt2 = setindex!!(copy(vnt), v, k) + @test vnt == vnt2 + end + # Check that the printed representation can be parsed back to an equal VarNamedTuple. + vnt3 = eval(Meta.parse(repr(vnt))) + @test vnt == vnt3 +end + @testset "VarNamedTuple" begin @testset "Construction" begin vnt1 = VarNamedTuple() + test_invariants(vnt1) vnt1 = setindex!!(vnt1, 1.0, @varname(a)) vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + test_invariants(vnt1) vnt2 = VarNamedTuple(; a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) ) + test_invariants(vnt2) @test vnt1 == vnt2 - pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() + pa1 = PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(; min_size=(16,)) + pa2 = PartialArray{Float64,1}(; min_size=(16,)) pa2 = setindex!!(pa2, 1.0, 16) - pa3 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(16 => 1.0) - pa4 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,) => 1.0) + pa3 = PartialArray{Float64,1}(16 => 1.0) + pa4 = PartialArray{Float64,1}((16,) => 1.0) @test pa1 == pa2 @test pa1 == pa3 @test pa1 == pa4 - pa1 = DynamicPPL.VarNamedTuples.PartialArray{String,3}() + pa1 = PartialArray{String,3}() pa1 = setindex!!(pa1, "a", 2, 3, 4) pa1 = setindex!!(pa1, "b", 1, 2, 4) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = PartialArray{String,3}(; min_size=(16, 16, 16)) pa2 = setindex!!(pa2, "a", 2, 3, 4) pa2 = setindex!!(pa2, "b", 1, 2, 4) - pa3 = DynamicPPL.VarNamedTuples.PartialArray{String,3}( - (2, 3, 4) => "a", (1, 2, 4) => "b" - ) + pa3 = PartialArray{String,3}((2, 3, 4) => "a", (1, 2, 4) => "b") @test pa1 == pa2 @test pa1 == pa3 - @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((0,) => 1) - @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1, 2) => 1) - @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1,) => "a") - @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}( - (1,) => 1; min_size=(2, 2) - ) + @test_throws BoundsError PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError PartialArray{Int,1}((1,) => "a") + @test_throws MethodError PartialArray{Int,1}((1,) => 1; min_size=(2, 2)) end @testset "Basic sets and gets" begin @@ -51,6 +71,7 @@ using BangBang: setindex!! @test @inferred(getindex(vnt, @varname(a))) == 32.0 @test haskey(vnt, @varname(a)) @test !haskey(vnt, @varname(b)) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] @@ -59,40 +80,50 @@ using BangBang: setindex!! 
@test haskey(vnt, @varname(b[1])) @test haskey(vnt, @varname(b[1:3])) @test !haskey(vnt, @varname(b[4])) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 64.0 @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + test_invariants(vnt) # These can't be @inferred because `d` now has an abstract element type. Note that this # does not ruin type stability for other varnames that don't involve `d`. vnt = setindex!!(vnt, "a", @varname(d[5])) @test getindex(vnt, @varname(d[5])) == "a" + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 @test haskey(vnt, @varname(e.f[3].g.h[2].i)) @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + test_invariants(vnt) vec = fill(1.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) @@ -101,12 +132,14 @@ using BangBang: setindex!! @test haskey(vnt, @varname(j[4])) @test !haskey(vnt, @varname(j[5])) @test_throws BoundsError getindex(vnt, @varname(j[5])) + test_invariants(vnt) vec = fill(2.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec @test haskey(vnt, @varname(j[5])) + test_invariants(vnt) arr = fill(2.0, (4, 2)) vn = @varname(k.l[2:5, 3, 1:2, 2]) @@ -114,6 +147,7 @@ using BangBang: setindex!! @test @inferred(getindex(vnt, vn)) == arr # A subset of the elements set just now. @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + test_invariants(vnt) # Not enough, or too many, indices. @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) @@ -128,11 +162,13 @@ using BangBang: setindex!! # A subset of the elements set previously. @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] @test !haskey(vnt, @varname(m[1])) + test_invariants(vnt) # The below tests are mostly significant for the type stability aspect. For the last # test to pass, PartialArray needs to actively tighten its eltype when possible. @@ -147,6 +183,7 @@ using BangBang: setindex!! # VarNamedTuple with a concrete element type, and hence getindex can be inferred. 
vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 + test_invariants(vnt) end @testset "equality" begin From f5616df867e2a685783686dbf1d382888a510974 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 16:44:23 +0000 Subject: [PATCH 030/148] Improve VNT internal docs --- docs/src/internals/varnamedtuple.md | 47 ++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 0194d05d7..7198aae9f 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -1,16 +1,16 @@ -# VarNamedTuple +# `VarNamedTuple` In DynamicPPL there is often a need to store data keyed by `VarName`s. This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. -Historically we've had several different approaches to this: Dictionaries, NamedTuples, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. +Historically we've had several different approaches to this: Dictionaries, `NamedTuple`s, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. To unify the treatment of these use cases, and handle them all in a robust and performant way, is the purpose of `VarNamedTuple`, aka VNT. It's a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type stable and performant manner. -`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`. +`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`s. Let's first talk about the `NamedTuple` part. This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`. -In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lens as the keys. +In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lenses as keys. For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as ``` @@ -30,11 +30,11 @@ x / \ y ``` If all `VarName`s consisted of only `PropertyLens`es we would be done designing the data structure. -However, recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). +However, recall that `VarName`s allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`. It is the `IndexLenses` that require more structure. -An `IndexLens` is the indexing layer in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. +An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. `VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. 
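[Editor's note: a quick hedged sketch of the indexing restriction just described — integer indices, explicit ranges with end points, and tuples thereof are supported as keys, while `Colon`s are not; the commented-out line is the unsupported case.]

```julia
using DynamicPPL: VarNamedTuple, @varname
using BangBang: setindex!!

vnt = VarNamedTuple()
vnt = setindex!!(vnt, 1.0, @varname(x[3]))                    # integer index: supported
vnt = setindex!!(vnt, [2.0, 3.0], @varname(x[1:2]))           # explicit range: supported
vnt = setindex!!(vnt, fill(4.0, 2, 3), @varname(y[1:2, 4:6])) # tuple of indices: supported
# vnt = setindex!!(vnt, [5.0], @varname(x[:]))                # Colon: not supported
```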
@@ -43,7 +43,7 @@ When we meet an `IndexLens`, we instead instert into the tree something called a
 A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset.
 It is a data structure we define ourselves for use within `VarNamedTuple`s.
-A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and this thus not an `AbstractArray`.
+A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and thus is not an `AbstractArray`.
 This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions.
 The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are.
@@ -52,8 +52,8 @@ A `Colon()` says that we should get or set all the values along that dimension,
 If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question.

 `PartialArray`s have other restrictions, compared to the full indexing syntax of Julia, as well:
-They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays as in `rand(4)[[true, false, true, false]]`).
-This is mostly because we haven't seen a need to support them, and implementing would complicate the codebase for little gain.
+They do not support linear indexing into multidimensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays (as in `rand(4)[[true, false, true, false]]`).
+This is mostly because we haven't seen a need to support them, and implementing them would complicate the codebase for little gain.
 We may add support for them later if needed.

 `PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s.
@@ -89,8 +89,20 @@ julia> vnt[@varname(d.e[2].f)]
 PartialArray{Symbol,1}((3,) => hip, (4,) => hop)
 ```

-The above example also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`.
-We do not define a method for `Base.setindex!` at all, the `setindex!!` is the only way.
+Or as a tree drawing, where `PA` marks a `PartialArray`:
+
+```
+ /----VNT------\
+a / | b \ d
+ 1 [2.0, 3.0] VNT
+ | e
+ PA(2 => VNT)
+ | f
+ PA(3 => :hip, 4 => :hop)
+```
+
+The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`.
+We do not define a method for `Base.setindex!` at all, `setindex!!` is the only way.
 This is because `VarNamedTuple` mixes mutable an immutable data structures.
 It is also for user convenience: One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in.
@@ -122,9 +134,15 @@ VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello))

 This approach is at the core of why `VarNamedTuple` is performant:
 As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type stable.
Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for `getindex` and `setindex!!`.
-Although, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`, you can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability.
+Although, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`;
+You can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability.

-Some miscellaneous notes
+Note that if you `setindex!!` a new value into a `VarNamedTuple` with an `IndexLens`, this causes a `PartialArray` to be created.
+However, if there already is a regular `Base.Array` stored in a `VarNamedTuple`, you can index into it with `IndexLens`es without involving `PartialArray`s.
+That is, if you do `vnt = setindex!!(vnt, [1.0, 2.0], @varname(a))`, you can then get the values with e.g. `vnt[@varname(a[1])]`, which returns 1.0.
+You can also set the elements with `vnt = setindex!!(vnt, 3.0, @varname(a[1]))`, and this will modify the existing `Base.Array`.
+At this point you cannot set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, 5.0, @varname(a[5]))`.
+The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is.

## Limitations

This design has several benefits, for performance and generality, but it also has limitations:

 1. The lack of support for `Colon`s in `VarName`s.
 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing.
- 3. An assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`.
+ 3. `VarNamedTuple` can not store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` can not be stored in the same `VarNamedTuple`.
+ 4. There is an assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`.
    The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`.
    The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`.

From ec5dc8f0a53029475418cefb062fe9fe346b2a7b Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 17:41:40 +0000
Subject: [PATCH 031/148] Polish VNT

---
 src/varnamedtuple.jl | 74 +++++++++++++++++++++++++++----------------
 test/varnamedtuple.jl | 8 ++++-
 2 files changed, 54 insertions(+), 28 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 2068566b4..47340f8f4 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -43,8 +43,8 @@ nested containers, and calls itself recursively on all elements that are found i
In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]` -is `y[a]`, whereas `_merge_recursive(x, y)[a]` be `_merge_recursive(x[a], y[a])`, unless -no specific method is defined for the type of `x` and `y`, in which case +is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`, +unless no specific method is defined for the type of `x` and `y`, in which case `_merge_recursive(x, y) === y` """ _merge_recursive(_, x2) = x2 @@ -81,12 +81,18 @@ way of saying whether the right hand side is of an acceptable size or not. The fact that its size is ill-defined also means that `PartialArray` is not a subtype of `AbstractArray`. -All indexing into `PartialArray`s are done with `getindex` and `setindex!!`. `setindex!`, +All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`, `push!`, etc. are not defined. The element type of a `PartialArray` will change as needed under `setindex!!` to accomoddate the new values. Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type -`ElType` and number of dimensions `numdims`. +`ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly +`numdims` elements. + +If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check +if, after the new value has been set, the element type can be made more concrete. If so, +a new `PartialArray` with a more concrete element type is returned. Thus the element type +of any `PartialArray` should always be as concrete as is allowed by the elements in it. The internal implementation of an `PartialArray` consists of two arrays: one holding the data and the other one being a boolean mask indicating which elements are defined. These @@ -113,13 +119,20 @@ struct PartialArray{ElType,num_dims} end """ - PartialArray{ElType,num_dims}(min_size=nothing) + PartialArray{ElType,num_dims}(args::Vararg{Pair}; min_size=nothing) + +Create a new `PartialArray`. -Create a new empty `PartialArray` with set element type and number of dimensions. +The element type and number of dimensions have to be specified explicitly as type +parameters. The positional arguments can be `Pair`s of indices and values. For example, +```jldoctest +julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) +PartialArray{Int,2}((1, 2) => 5, (3, 4) => 10) +``` -The optional argument `min_size` can be used to specify the minimum initial size. This is -purely a performance optimisation, to avoid resizing if the eventual size is known ahead of -time. +The optional keywoard argument `min_size` can be used to specify the minimum initial size. +This is purely a performance optimisation, to avoid resizing if the eventual size is known +ahead of time. """ function PartialArray{ElType,num_dims}( args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing @@ -376,12 +389,12 @@ end function _merge_recursive(pa1::PartialArray, pa2::PartialArray) if ndims(pa1) != ndims(pa2) throw( - ArgumentError("Cannot merge PartialArrays with different number of dimensions") + ArgumentError("Cannot merge PartialArrays with different numbers of dimensions") ) end num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) - result = if merge_size == _internal_size(pa2) + return if merge_size == _internal_size(pa2) # Either pa2 is strictly bigger than pa1 or they are equal in size. 
result = copy(pa2)
         for i in CartesianIndices(pa1.data)
@@ -426,23 +439,22 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray)
         result
     end
 end
-    return result
-end

 function Base.keys(pa::PartialArray)
     inds = findall(pa.mask)
     lenses = map(x -> IndexLens(Tuple(x)), inds)
     ks = Any[]
-    for l in lenses
-        val = getindex(pa.data, l.indices...)
+    for lens in lenses
+        val = getindex(pa.data, lens.indices...)
         if val isa VarNamedTuple
             subkeys = keys(val)
             for vn in subkeys
-                lens = varname_to_lens(vn)
-                push!(ks, _compose_no_identity(lens, l))
+                sublens = _varname_to_lens(vn)
+                push!(ks, _compose_no_identity(sublens, lens))
             end
         else
-            push!(ks, l)
+            push!(ks, lens)
         end
     end
     return ks
@@ -455,8 +467,8 @@

 A `NamedTuple`-like structure with `VarName` keys.

 `VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an
 efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and
-`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods
-are `merge`, which recursively merges two `VarNamedTuple`s.
+`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Another notable method
+is `merge`, which recursively merges two `VarNamedTuple`s.

 There are two major limitations to indexing by VarNamedTuples:

@@ -470,6 +482,9 @@ heterogeneous data under different indices of the same symbol. That is, if one e
 * sets `a[1].b` and `a[2].c`, without setting `a[1].c` or `a[2].b`, then getting values
 for `a[1]` or `a[2]` will not be type stable.
+
+`VarNamedTuple` is intrinsically linked to `PartialArray`, which it'll use to store data
+related to `VarName`s with `IndexLens` components.
 """
 struct VarNamedTuple{Names,Values}
     data::NamedTuple{Names,Values}
@@ -513,26 +528,29 @@ end
     varname_to_lens(name::VarName{S}) where {S}

 Convert a `VarName` to an `Accessor` lens, wrapping the first symbol in a `PropertyLens`.
+
+This is used to simplify method dispatch for `_getindex`, `_setindex!!`, and `_haskey`, by
+considering `VarName`s to just be a special case of lenses.
 """
-function varname_to_lens(name::VarName{S}) where {S}
+function _varname_to_lens(name::VarName{S}) where {S}
     return _compose_no_identity(getoptic(name), PropertyLens{S}())
 end

-_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name))
+_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, _varname_to_lens(name))
 _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = getindex(vnt.data, S)
 _getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name]

-_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name))
+_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, _varname_to_lens(name))
 _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S)
 _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true
 _haskey(::VarNamedTuple, ::IndexLens) = false

 function _setindex!!(vnt::VarNamedTuple, value, name::VarName)
-    return _setindex!!(vnt, value, varname_to_lens(name))
+    return _setindex!!(vnt, value, _varname_to_lens(name))
 end

 function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S}
-    # I would like this to just read
+    # I would like for this to just read
     # return VarNamedTuple(_setindex!!(vnt.data, value, S))
     # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the
     # below?
@@ -556,6 +574,8 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple)
     return VarNamedTuple(result_data)
 end

+# TODO(mhauru) The below remains unfinished and undertested. I think it's incorrect for more
+# complex VarNames. It is unexported though.
 """
     apply!!(func, vnt::VarNamedTuple, name::VarName)

@@ -565,7 +585,7 @@ Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name
 julia> vnt = VarNamedTuple()
 ()

-julia> vnt = setindex!!(vnt, [1,2,3], @varname(a))
+julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a))
 (a -> [1, 2, 3])

 julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a))
 (a -> [2, 3, 4])
@@ -650,12 +670,12 @@ function make_leaf(value, optic::ComposedFunction)
     return make_leaf(sub, optic.inner)
 end

-function make_leaf(value, optic::IndexLens{T}) where {T}
+function make_leaf(value, optic::IndexLens)
     inds = optic.indices
     num_inds = length(inds)
     # Check if any of the indices are ranges or colons. If yes, value needs to be an
     # AbstractArray. Otherwise it needs to be an individual value.
-    et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value)
+    et = _is_multiindex(inds) ? eltype(value) : typeof(value)
     pa = PartialArray{et,num_inds}()
     return _setindex!!(pa, value, optic)
 end
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index f55f8b996..803d8c546 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -19,10 +19,12 @@ function test_invariants(vnt::VarNamedTuple)
         v = getindex(vnt, k)
         vnt2 = setindex!!(copy(vnt), v, k)
         @test vnt == vnt2
+        @test hash(vnt) == hash(vnt2)
     end
     # Check that the printed representation can be parsed back to an equal VarNamedTuple.
     vnt3 = eval(Meta.parse(repr(vnt)))
     @test vnt == vnt3
+    @test hash(vnt) == hash(vnt3)
 end

From 3ca36c48b9b52175a788407a90d11da940e12cac Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 18:26:06 +0000
Subject: [PATCH 032/148] Make VNT merge type stable. Simplify printing,
 improve tests.

---
 src/varnamedtuple.jl | 55 ++++++++++++++++++----------------
 test/varnamedtuple.jl | 70 +++++++++++++++++++++++++++------------------
 2 files changed, 80 insertions(+), 45 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 47340f8f4..e49e1cb66 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -461,7 +461,7 @@ function Base.keys(pa::PartialArray)
 end

 """
-    VarNamedTuple{Names,Values}
+    VarNamedTuple{names,Values}

 A `NamedTuple`-like structure with `VarName` keys.
@@ -496,27 +496,19 @@ Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) function Base.show(io::IO, vnt::VarNamedTuple) - print(io, "VarNamedTuple(;") - for (i, (name, value)) in enumerate(pairs(vnt.data)) - if i > 1 - print(io, ",") - end - print(io, " ") - print(io, name) - print(io, "=") - # Note the distinction: The raw strings that form part of the structure of the print - # out are `print`ed, whereas the value itself is `show`n. The latter ensures that - # strings are quoted, Symbols are prefixed with :, etc. - show(io, value) + if isempty(vnt.data) + return print(io, "VarNamedTuple()") end - return print(io, ")") + print(io, "VarNamedTuple") + show(io, vnt.data) + return nothing end -function Base.copy(vnt::VarNamedTuple{Names}) where {Names} +function Base.copy(vnt::VarNamedTuple{names}) where {names} # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements, # which we recursively copy. return VarNamedTuple( - NamedTuple{Names}( + NamedTuple{names}( map( x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data) ), @@ -559,19 +551,25 @@ end Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) -# TODO(mhauru) Check the performance of this, and make it into a generated function if -# necessary. -function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - result_data = vnt1.data - for k in keys(vnt2.data) - val = if haskey(result_data, k) - _merge_recursive(result_data[k], vnt2.data[k]) +# This needs to be a generated function for type stability. +@generated function _merge_recursive( + vnt1::VarNamedTuple{names1}, vnt2::VarNamedTuple{names2} +) where {names1,names2} + all_names = union(names1, names2) + exs = Expr[] + push!(exs, :(data = (;))) + for name in all_names + val_expr = if name in names1 && name in names2 + :(_merge_recursive(vnt1.data[$(QuoteNode(name))], vnt2.data[$(QuoteNode(name))])) + elseif name in names1 + :(vnt1.data[$(QuoteNode(name))]) else - vnt2.data[k] + :(vnt2.data[$(QuoteNode(name))]) end - Accessors.@reset result_data[k] = val + push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,))))) end - return VarNamedTuple(result_data) + push!(exs, :(return VarNamedTuple(data))) + return Expr(:block, exs...) end # TODO(mhauru) The below remains unfinished an undertested. I think it's incorrect for more @@ -601,6 +599,11 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end +# TODO(mhauru) Should this return tuples, like it does now? That makes sense for +# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big. +# Also, this is not very type stable, it fails even in basic cases. A generated function +# would help, but I failed to make one. Might be something to do with a recursive +# generated function. function Base.keys(vnt::VarNamedTuple) result = () for sym in keys(vnt.data) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 803d8c546..53ce10e94 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,6 +3,7 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray +using AbstractPPL: VarName, prefix using BangBang: setindex!! 
""" @@ -25,6 +26,9 @@ function test_invariants(vnt::VarNamedTuple) vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 @test hash(vnt) == hash(vnt3) + # Check that merge with an empty VarNamedTuple is a no-op. + @test merge(vnt, VarNamedTuple()) == vnt + @test merge(VarNamedTuple(), vnt) == vnt end @testset "VarNamedTuple" begin @@ -186,6 +190,32 @@ end vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 test_invariants(vnt) + + # Some funky Symbols in VarNames + # TODO(mhauru) This still isn't as robust as it should be, for instance Symbol(":") + # fails the eval(Meta.parse(print(vnt))) == vnt test because NamedTuple show doesn't + # respect the eval-property. + vn1 = VarName{Symbol("a b c")}() + vnt = @inferred(setindex!!(vnt, 2, vn1)) + @test @inferred(getindex(vnt, vn1)) == 2 + test_invariants(vnt) + vn2 = VarName{Symbol("1")}() + vnt = @inferred(setindex!!(vnt, 3, vn2)) + @test @inferred(getindex(vnt, vn2)) == 3 + test_invariants(vnt) + vn3 = VarName{Symbol("?!")}() + vnt = @inferred(setindex!!(vnt, 4, vn3)) + @test @inferred(getindex(vnt, vn3)) == 4 + test_invariants(vnt) + vnt = VarNamedTuple() + vn4 = prefix(prefix(vn1, vn2), vn3) + vnt = @inferred(setindex!!(vnt, 5, vn4)) + @test @inferred(getindex(vnt, vn4)) == 5 + test_invariants(vnt) + vn5 = prefix(prefix(vn3, vn2), vn1) + vnt = @inferred(setindex!!(vnt, 6, vn5)) + @test @inferred(getindex(vnt, vn5)) == 6 + test_invariants(vnt) end @testset "equality" begin @@ -229,7 +259,7 @@ end vnt2 = VarNamedTuple() expected_merge = VarNamedTuple() # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, 1.0, @varname(a)) vnt2 = setindex!!(vnt2, 2.0, @varname(b)) @@ -238,7 +268,7 @@ end expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) expected_merge = setindex!!(expected_merge, 2, @varname(c)) expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -250,7 +280,7 @@ end expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) @@ -259,13 +289,13 @@ end vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) @@ -289,9 +319,9 @@ end expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) expected_merge_12 = setindex!!(expected_merge_12, 2, 
@varname(a[1]))
         expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2]))
-        @test merge(vnt1, vnt2) == expected_merge_12
+        @test @inferred(merge(vnt1, vnt2)) == expected_merge_12
         expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1]))
-        @test merge(vnt2, vnt1) == expected_merge_21
+        @test @inferred(merge(vnt2, vnt1)) == expected_merge_21

         vnt1 = VarNamedTuple()
         vnt2 = VarNamedTuple()
@@ -310,11 +340,13 @@ end

     @testset "keys" begin
         vnt = VarNamedTuple()
-        @test keys(vnt) == ()
+        @test @inferred(keys(vnt)) == ()
         @test all(x -> haskey(vnt, x), keys(vnt))

         vnt = setindex!!(vnt, 1.0, @varname(a))
-        @test keys(vnt) == (@varname(a),)
+        # TODO(mhauru) Note that the below passes @inferred, but the later ones don't.
+        # We should improve type stability of keys().
+        @test @inferred(keys(vnt)) == (@varname(a),)
         @test all(x -> haskey(vnt, x), keys(vnt))

         vnt = setindex!!(vnt, [1, 2, 3], @varname(b))
@@ -394,26 +426,26 @@ end
         io = IOBuffer()
         show(io, vnt)
         output = String(take!(io))
-        @test output == "VarNamedTuple(;)"
+        @test output == "VarNamedTuple()"

         vnt = setindex!!(vnt, "s", @varname(a))
         io = IOBuffer()
         show(io, vnt)
         output = String(take!(io))
-        @test output == """VarNamedTuple(; a="s")"""
+        @test output == """VarNamedTuple(a = "s",)"""

         vnt = setindex!!(vnt, [1, 2, 3], @varname(b))
         io = IOBuffer()
         show(io, vnt)
         output = String(take!(io))
-        @test output == """VarNamedTuple(; a="s", b=[1, 2, 3])"""
+        @test output == """VarNamedTuple(a = "s", b = [1, 2, 3])"""

         vnt = setindex!!(vnt, :dada, @varname(c[2]))
         io = IOBuffer()
         show(io, vnt)
         output = String(take!(io))
         @test output == """
-        VarNamedTuple(; a="s", b=[1, 2, 3], c=PartialArray{Symbol,1}((2,) => :dada))"""
+        VarNamedTuple(a = "s", b = [1, 2, 3], \
+        c = PartialArray{Symbol,1}((2,) => :dada))"""

         vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2]))
         io = IOBuffer()
         show(io, vnt)
         output = String(take!(io))
+        # Depending on what's in scope, and maybe sometimes even the Julia version,
+        # sometimes types in the output are fully qualified, sometimes not. To avoid
+        # brittle tests, we normalise the output:
+        output = replace(output, "DynamicPPL." => "", "VarNamedTuples." => "")
         @test output == """
-        VarNamedTuple(; a="s", b=[1, 2, 3], \
-        c=PartialArray{Symbol,1}((2,) => :dada), \
-        d=VarNamedTuple(; \
-        e=PartialArray{VarNamedTuple{(:f,), \
-        Tuple{VarNamedTuple{(:g,), \
-        Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \
-        VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \
-        (2,) => 17.0))))))"""
+        VarNamedTuple(a = "s", b = [1, 2, 3], \
+        c = PartialArray{Symbol,1}((2,) => :dada), \
+        d = VarNamedTuple(\
+        e = PartialArray{VarNamedTuple{(:f,), \
+        Tuple{VarNamedTuple{(:g,), \
+        Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \
+        VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \
+        (2,) => 17.0),),)),))"""
     end
 end

From 59f67fd1c62a5ad8885fa207f666dc58d713117c Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 19:23:45 +0000
Subject: [PATCH 033/148] Add VNT to API docs

---
 docs/src/api.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/docs/src/api.md b/docs/src/api.md
index adb476db5..a3d93aa22 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -364,6 +364,12 @@ Base.empty!
 SimpleVarInfo
 ```

+#### `VarNamedTuple`
+
+```@docs
+VarNamedTuple
+```
+
 ### Accumulators

 The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
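[Editor's note: the `@inferred(merge(...))` tests above exercise the recursive merge semantics described in the `_merge_recursive` docstring earlier in this series. A minimal sketch of the behaviour being asserted; the commented results follow from the docstring, not from verified output.]

```julia
using DynamicPPL: VarNamedTuple, @varname
using BangBang: setindex!!

vnt1 = setindex!!(VarNamedTuple(), 1.0, @varname(x.a))
vnt2 = setindex!!(VarNamedTuple(), 2.0, @varname(x.b))
merged = merge(vnt1, vnt2)
merged[@varname(x.a)]  # 1.0: the nested entries under `x` are merged recursively,
merged[@varname(x.b)]  # 2.0: rather than `vnt2`'s `x` replacing `vnt1`'s wholesale
```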
From 9aba468eeb65471f26469475d168b90cdcf54b61 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 19:27:36 +0000
Subject: [PATCH 034/148] Fix doctests

---
 src/varnamedtuple.jl | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index e49e1cb66..d3f2ba13a 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -126,8 +126,10 @@ Create a new `PartialArray`.
 The element type and number of dimensions have to be specified explicitly as type
 parameters. The positional arguments can be `Pair`s of indices and values. For example,
 ```jldoctest
+julia> using DynamicPPL.VarNamedTuples: PartialArray
+
 julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10)
-PartialArray{Int,2}((1, 2) => 5, (3, 4) => 10)
+PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10)
 ```

 The optional keyword argument `min_size` can be used to specify the minimum initial size.
@@ -580,14 +582,18 @@ Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name

 ```jldoctest
+julia> using DynamicPPL: VarNamedTuple, setindex!!
+
+julia> using DynamicPPL.VarNamedTuples: apply!!
+
 julia> vnt = VarNamedTuple()
-()
+VarNamedTuple()

 julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a))
-(a -> [1, 2, 3])
+VarNamedTuple(a = [1, 2, 3],)

-julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a))
-(a -> [2, 3, 4])
+julia> apply!!(x -> x .+ 1, vnt, @varname(a))
+VarNamedTuple(a = [2, 3, 4],)
 ```
 """
 function apply!!(func, vnt::VarNamedTuple, name::VarName)

From 0b4c772460f8283b45e919f7e6277db864f2371b Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 19:39:24 +0000
Subject: [PATCH 035/148] Clean up tests a bit

---
 test/varnamedtuple.jl | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index 53ce10e94..67f3d5c2b 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -258,7 +258,6 @@ end
         vnt1 = VarNamedTuple()
         vnt2 = VarNamedTuple()
         expected_merge = VarNamedTuple()
-        # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense.
         @test @inferred(merge(vnt1, vnt2)) == expected_merge

         vnt1 = setindex!!(vnt1, 1.0, @varname(a))
         vnt2 = setindex!!(vnt2, 2.0, @varname(b))
@@ -341,29 +340,23 @@ end
     @testset "keys" begin
         vnt = VarNamedTuple()
         @test @inferred(keys(vnt)) == ()
-        @test all(x -> haskey(vnt, x), keys(vnt))

         vnt = setindex!!(vnt, 1.0, @varname(a))
         # TODO(mhauru) Note that the below passes @inferred, but the later ones don't.
         # We should improve type stability of keys().
@test @inferred(keys(vnt)) == (@varname(a),) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) @test keys(vnt) == (@varname(a), @varname(b)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 15, @varname(b[2])) @test keys(vnt) == (@varname(a), @varname(b)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, [10], @varname(c.x.y)) @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, -1.0, @varname(d[4])) @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) @test keys(vnt) == ( @@ -373,7 +366,6 @@ end @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) @test keys(vnt) == ( @@ -387,7 +379,6 @@ end @varname(j[3]), @varname(j[4]), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(j[6])) @test keys(vnt) == ( @@ -402,7 +393,6 @@ end @varname(j[4]), @varname(j[6]), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) @test keys(vnt) == ( @@ -418,7 +408,6 @@ end @varname(j[6]), @varname(n[2].a), ) - @test all(x -> haskey(vnt, x), keys(vnt)) end @testset "printing" begin @@ -445,7 +434,8 @@ end show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(a = "s", b = [1, 2, 3], c = PartialArray{Symbol,1}((2,) => :dada))""" + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() From 38662a8537ac2d00f10f743affd6116c92e06e4a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:46:22 +0000 Subject: [PATCH 036/148] Fix API docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index a3d93aa22..1a22e4cdb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -367,7 +367,7 @@ SimpleVarInfo #### `VarNamedTuple` ```@docs -VarNamedTuple +DynamicPPL.VarNamedTuples.VarNamedTuple ``` ### Accumulators From e41afcaf063124cd61fbefe3a58c5966be8ca61f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:47:34 +0000 Subject: [PATCH 037/148] Fix a bug and a docstring --- src/chains.jl | 5 +---- src/contexts/init.jl | 8 -------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index d01606c3d..4d69b3590 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -136,10 +136,7 @@ function ParamsWithStats( include_log_probs::Bool=true, ) where {Tlink} strategy = InitFromParams( - VectorWithRanges{Tlink}( - ldf._iden_varname_ranges, ldf._varname_ranges, param_vector - ), - nothing, + VectorWithRanges{Tlink}(ldf._varname_ranges, param_vector), nothing ) accs = if include_log_probs ( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 305f28767..90394a24c 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -227,14 +227,6 @@ this `VectorWithRanges` are linked/not linked, or `nothing` if either the linkin not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not affect functionality or correctness, but causes more work to be done at runtime, with possible impacts on type stability and performance. 
- -In the simplest case, this could be accomplished only with a single dictionary mapping -VarNames to ranges and link status. However, for performance reasons, we separate out -VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All -non-identity-optic VarNames are stored in the `varname_ranges` Dict. - -It would be nice to improve the NamedTuple and Dict approach. See, e.g. -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} # Ranges for all VarNames From 8c50bbb2bd3e6aaff059ea9a5a3afe6455f5a4b1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Nov 2025 09:14:02 +0000 Subject: [PATCH 038/148] Apply suggestions from code review Co-authored-by: Penelope Yong --- docs/src/internals/varnamedtuple.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 7198aae9f..2d787574b 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -35,7 +35,7 @@ The `identity` lens presents no complications, and in fact in the above example It is the `IndexLenses` that require more structure. An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. -`VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. +`VarNamedTuple` cannot deal with `IndexLens`es in their full generality, for reasons we'll discuss below. Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`. @@ -73,7 +73,7 @@ julia> print(vnt) VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) ``` -The output there may be a bit hard bit hard to parse, so to illustrate: +The output there may be a bit hard to parse, so to illustrate: ```julia julia> vnt[@varname(b)] @@ -103,7 +103,7 @@ a / | b \ d The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. We do not define a method for `Base.setindex!` at all, `setindex!!` is the only way. -This is because `VarNamedTuple` mixes mutable an immutable data structures. +This is because `VarNamedTuple` mixes mutable and immutable data structures. It is also for user convenience: One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. Rather the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed. @@ -150,7 +150,7 @@ This design has a several of benefits, for performance and generality, but it al 1. The lack of support for `Colon`s in `VarName`s. 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing. - 3. `VarNamedTuple` can not store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` can not be stored in the same `VarNamedTuple`. - 4. 
There is an assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. + 3. `VarNamedTuple` cannot store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` cannot be stored in the same `VarNamedTuple`. + 4. There is an asymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`. The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`. From cae8864c636bb33d49d26a8aaaf19ec173935689 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Nov 2025 09:16:25 +0000 Subject: [PATCH 039/148] Fix VNT docs --- docs/src/internals/varnamedtuple.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 2d787574b..47ff9c65e 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -94,9 +94,9 @@ Or as a tree drawing, where `PA` marks a `PartialArray`: ``` /----VNT------\ a / | b \ d - 1 [2.0, 3.0] VNT - | e - PA(2 => VNT) + 1 VNT VNT + | c | e + [2.0, 3.0] PA(2 => VNT) | f PA(3 => :hip, 4 => :hop) ``` From c27f5e0854b11e21c960569dcf478df2dba066a6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 13:22:04 +0000 Subject: [PATCH 040/148] Make threadsafe evaluation opt-in (#1151) * Make threadsafe evaluation opt-in * Reduce number of type parameters in methods * Make `warned_warn_about_threads_threads_threads_threads` shorter * Improve `setthreadsafe` docstring * warn on bare `@threads` as well * fix merge * Fix performance issues * Use maxthreadid() in TSVI * Move convert_eltype code to threadsafe eval function * Point to new Turing docs page * Add a test for setthreadsafe * Tidy up check_model * Apply suggestions from code review Fix outdated docstrings Co-authored-by: Markus Hauru * Improve warning message * Export `requires_threadsafe` * Add an actual docstring for `requires_threadsafe` --------- Co-authored-by: Markus Hauru --- HISTORY.md | 39 +++++++- docs/src/api.md | 8 ++ src/DynamicPPL.jl | 2 + src/compiler.jl | 54 ++++++++--- src/debug_utils.jl | 6 +- src/model.jl | 185 ++++++++++++++++++------------------- src/simple_varinfo.jl | 10 +- src/threadsafe.jl | 7 +- src/varinfo.jl | 16 +--- test/compiler.jl | 42 +++++++-- test/logdensityfunction.jl | 78 ++++++++-------- test/threadsafe.jl | 75 +++++---------- 12 files changed, 281 insertions(+), 241 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ff28349d8..5dcb008d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -9,12 +9,49 @@ This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. -For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. 
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.
 
 As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
 In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
 If you were previously relying on this behaviour, you will need to store a VarInfo separately.
 
+#### Threadsafe evaluation
+
+DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel.
+Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread.
+
+In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
+If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as such, using:
+
+```julia
+@model f() = ...
+model = f()
+model = setthreadsafe(model, true)
+```
+
+The problem with the previous on-by-default behaviour is that it can sacrifice a huge amount of performance when thread safety is not needed.
+This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
+Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
+
+**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
+This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
+
+  - tilde-statements
+  - calls to `@addlogprob!`
+  - any direct manipulation of the special `__varinfo__` variable
+
+If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
+**Notably, the following do not require threadsafe evaluation:**
+
+  - Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
+  - Sampling with `AbstractMCMC.MCMCThreads()`.
+
+For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).
+
+When threadsafe evaluation is enabled for a model, an internal flag is set on the model.
+The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean.
+This function is newly exported in this version of DynamicPPL.
+
 #### Parent and leaf contexts
 
 The `DynamicPPL.NodeTrait` function has been removed.
diff --git a/docs/src/api.md b/docs/src/api.md
index adb476db5..193a6ce4c 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref):
 contextualize
 ```
 
+Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
+If this is the case, one must enable threadsafe evaluation for a model:
+
+```@docs
+setthreadsafe
+requires_threadsafe
+```
+
 ## Evaluation
 
 With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). 
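As a minimal sketch of the two functions documented in the `@docs` block above (the model `demo` here is hypothetical and not part of this patch):

```julia
using DynamicPPL, Distributions

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x)
end

model = demo(1.0)
requires_threadsafe(model)          # false: threadsafe evaluation is off by default
model = setthreadsafe(model, true)  # opt in to ThreadSafeVarInfo-wrapped evaluation
requires_threadsafe(model)          # true
```

Note that `setthreadsafe` returns a new `Model` rather than mutating its argument, which is why the result is rebound above.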
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index a885f6a96..fda428eaa 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -90,6 +90,8 @@ export AbstractVarInfo,
     Model,
     getmissings,
     getargnames,
+    setthreadsafe,
+    requires_threadsafe,
     extract_priors,
     values_as_in_model,
     # evaluation
diff --git a/src/compiler.jl b/src/compiler.jl
index 3324780ca..1b4260121 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
     modeldef = build_model_definition(expr)
 
     # Generate main body
-    modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
+    modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true)
 
     return build_output(modeldef, linenumbernode)
 end
@@ -346,10 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg
 
 If `warn` is true, a warning is displayed if internal variables are used in the model
 definition.
 """
-generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
+generate_mainbody(mod, expr, warn, warn_threads) =
+    generate_mainbody!(mod, Symbol[], expr, warn, warn_threads)
 
-generate_mainbody!(mod, found, x, warn) = x
-function generate_mainbody!(mod, found, sym::Symbol, warn)
+generate_mainbody!(mod, found, x, warn, warn_threads) = x
+function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads)
     if warn && sym in INTERNALNAMES && sym ∉ found
         @warn "you are using the internal variable `$sym`"
         push!(found, sym)
@@ -357,17 +358,39 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
     return sym
 end
 
-function generate_mainbody!(mod, found, expr::Expr, warn)
+function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads)
     # Do not touch interpolated expressions
     expr.head === :$ && return expr.args[1]
 
+    # Flag to determine whether we've issued a warning for threadsafe macros. Note that this
+    # detection is not fully correct. We can only detect the presence of a macro that has
+    # the symbol `Threads.@threads`, however, we can't detect if that *is actually*
+    # Threads.@threads from Base.Threads.
+
     # Do we don't want escaped expressions because we unfortunately
     # escape the entire body afterwards.
-    Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
+    Meta.isexpr(expr, :escape) &&
+        return generate_mainbody(mod, found, expr.args[1], warn, warn_threads)
 
     # If it's a macro, we expand it
    if Meta.isexpr(expr, :macrocall)
-        return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
+        if (
+            expr.args[1] == Symbol("@threads") ||
+            expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
+            warn_threads
+        )
+            warn_threads = false
+            @warn (
+                "It looks like you are using `Threads.@threads` in your model definition." *
+                "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
+                " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
+                "\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." *
+                "\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required."
+            )
+        end
+        return generate_mainbody!(
+            mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads
+        )
    end
 
     # Modify dotted tilde operators. 
@@ -375,7 +398,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde return generate_mainbody!( - mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads ) end @@ -385,8 +408,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end @@ -397,13 +420,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_assign return Base.remove_linenums!( generate_assign( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)..., + ) end function generate_assign(left, right) @@ -699,7 +725,7 @@ function build_output(modeldef, linenumbernode) # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) - return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...)) + return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...)) end return MacroTools.@q begin diff --git a/src/debug_utils.jl b/src/debug_utils.jl index e8b50a0b7..8810b9819 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -424,8 +424,10 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a + # check on the merged accumulator, rather than checking it in the accumulate_assume + # calls. That way we can also correctly support multi-threaded evaluation. + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/model.jl b/src/model.jl index 7d5bbf2fb..e82fdc60c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However, non-traditional use-cases `missings` can be defined differently. All variables in `missings` are treated as random variables rather than observations. +The `Threaded` type parameter indicates whether the model requires threadsafe evaluation +(i.e., whether the model contains statements which modify the internal VarInfo that are +executed in parallel). By default, this is set to `false`. + The default arguments are used internally when constructing instances of the same model with different arguments. 
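As a sketch of what the default `Threaded = false` still allows (the model below is hypothetical and not part of this patch; under AD the `Float64` buffer would need widening): threaded work that never touches the VarInfo is safe without the flag, provided the single `@addlogprob!` call stays outside the parallel region. The purely syntactic `Threads.@threads` detection added above will still emit its advisory warning for this pattern.

```julia
using DynamicPPL, Distributions

@model function no_flag_needed(y)
    x ~ Normal()
    lls = Vector{Float64}(undef, length(y))
    # Only plain Julia values are written inside the threaded loop; there is
    # no tilde-statement, @addlogprob!, or __varinfo__ access here.
    Threads.@threads for i in eachindex(y)
        lls[i] = logpdf(Normal(x), y[i])
    end
    # The VarInfo is only touched here, outside the parallel region.
    @addlogprob! sum(lls)
end
```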
@@ -33,26 +37,27 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx @doc """ - Model{missings}(f, args::NamedTuple, defaults::NamedTuple) + Model{Threaded,missings}(f, args::NamedTuple, defaults::NamedTuple) Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ - function Model{missings}( + function Model{Threaded,missings}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}( f, args, defaults, context ) end @@ -66,23 +71,39 @@ Create a model with evaluation function `f` and missing arguments deduced from ` Default arguments `defaults` are used internally when constructing instances of the same model with different arguments. """ -@generated function Model( +@generated function Model{Threaded}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), -) where {F,argnames,Targs,kwargnames,Tkwargs} +) where {Threaded,F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{Threaded,$(missing_args..., missing_kwargs...)}( + f, args, defaults, context + )) +end + +function Model{Threaded}( + f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs... +) where {Threaded} + return Model{Threaded}(f, args, NamedTuple(kwargs), context) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +""" + requires_threadsafe(model::Model) + +Return whether `model` has been marked as needing threadsafe evaluation (using +`setthreadsafe`). +""" +function requires_threadsafe( + ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return Threaded end """ @@ -92,7 +113,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model{requires_threadsafe(model)}(model.f, model.args, model.defaults, context) end """ @@ -105,6 +126,33 @@ function setleafcontext(model::Model, context::AbstractContext) return contextualize(model, setleafcontext(model.context, context)) end +""" + setthreadsafe(model::Model, threadsafe::Bool) + +Returns a new `Model` with its threadsafe flag set to `threadsafe`. 
+ +Threadsafe evaluation ensures correctness when executing model statements that mutate the +internal `VarInfo` object in parallel. For example, this is needed if tilde-statements are +nested inside `Threads.@threads` or similar constructs. + +It is not needed for generic multithreaded operations that don't involve VarInfo. For +example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` +outside of the parallel region is safe without needing to set `threadsafe=true`. + +It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. + +Setting `threadsafe` to `true` increases the overhead in evaluating the model. Please see +[the Turing.jl docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more +details. +""" +function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} + return if requires_threadsafe(model) == threadsafe + model + else + Model{threadsafe,M}(model.f, model.args, model.defaults, model.context) + end +end + """ model | (x = 1.0, ...) @@ -863,16 +911,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -889,10 +927,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -@inline function init!!( - # Note that this `@inline` is mandatory for performance, especially for - # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for - # trivial models) and much slower runtime. +function init!!( rng::Random.AbstractRNG, model::Model, vi::AbstractVarInfo, @@ -900,36 +935,11 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. ) ctx = InitContext(rng, strategy) model = DynamicPPL.setleafcontext(model, ctx) - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - return if Threads.nthreads() > 1 - # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that - # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` - # won't have been filled with parameters prior to `init!!` being called. - # - # Note that this eltype promotion is only needed for threadsafe evaluation. In an - # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a - # similar method. In other words, it should not be here, and it should not be inside - # `unflatten` either. The problem is performance. Shifting this code around can have - # massive, inexplicable, impacts on performance. This should be investigated - # properly. 
-        param_eltype = DynamicPPL.get_param_eltype(strategy)
-        accs = map(vi.accs) do acc
-            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
-        end
-        vi = DynamicPPL.setaccs!!(vi, accs)
-        tsvi = ThreadSafeVarInfo(resetaccs!!(vi))
-        retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
-        return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new))
-    else
-        return DynamicPPL._evaluate!!(model, resetaccs!!(vi))
-    end
+    return DynamicPPL.evaluate!!(model, vi)
 end
 
-@inline function init!!(
+function init!!(
     model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior()
 )
-    # This `@inline` is also mandatory for performance
     return init!!(Random.default_rng(), model, vi, strategy)
 end
 
@@ -938,55 +948,42 @@ end
 
 Evaluate the `model` with the given `varinfo`.
 
-If multiple threads are available, the varinfo provided will be wrapped in a
-`ThreadSafeVarInfo` before evaluation.
+If the model has been marked as requiring threadsafe evaluation, the varinfo
+provided will be wrapped in a `ThreadSafeVarInfo` before evaluation.
 
 Returns a tuple of the model's return value, plus the updated `varinfo`
 (unwrapped if necessary).
 """
 function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
-    return if use_threadsafe_eval(model.context, varinfo)
-        evaluate_threadsafe!!(model, varinfo)
+    return if requires_threadsafe(model)
+        # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
+        # a gradient type of some AD backend.
+        # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
+        # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but
+        # the accumulators in the VarInfo are plain floats, we error since we can't change the
+        # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
+        # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
+        # plain ugly and hacky.
+        # The below line is finicky for type stability. For instance, assigning the eltype to
+        # convert to into an intermediate variable makes this unstable (constant propagation
+        # fails). Take care when editing.
+        param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context)
+        accs = map(DynamicPPL.getaccs(varinfo)) do acc
+            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
+        end
+        varinfo = DynamicPPL.setaccs!!(varinfo, accs)
+        wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
+        result, wrapper_new = _evaluate!!(model, wrapper)
+        # TODO(penelopeysm): It seems that if you pass a TSVI to this method, it
+        # will return the underlying VI, which is a bit counterintuitive (because
+        # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
+        # again).
+        return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
     else
-        evaluate_threadunsafe!!(model, varinfo)
+        _evaluate!!(model, resetaccs!!(varinfo))
     end
 end
 
-"""
-    evaluate_threadunsafe!!(model, varinfo)
-
-Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
-
-If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
-This method is not exposed and supposed to be used only internally in DynamicPPL. 
- -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) - return _evaluate!!(model, resetaccs!!(varinfo)) -end - -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 434480be6..9d3fb1925 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -278,15 +278,7 @@ end function unflatten(svi::SimpleVarInfo, x::AbstractVector) vals = unflatten(svi.values, x) - # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is - # required but undesireable. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) - ) - return SimpleVarInfo(vals, accs, svi.transformation) + return SimpleVarInfo(vals, svi.accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 89877f385..0e906b6ca 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo) # fields. This is not good practice --- see # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. - # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()] return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..14e08515c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -367,21 +367,7 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is - # a gradient type of some AD backend. - # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! - # for ThreadSafeVarInfo. 
In that one, if the map produces e.g a ForwardDiff.Dual, but - # the accumulators in the VarInfo are plain floats, we error since we can't change the - # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here - # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just - # plain ugly and hacky. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) - ) - return VarInfo(md, accs) + return VarInfo(md, vi.accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..9056f666a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. @model function demo() @@ -793,4 +788,39 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "Threads.@threads detection" begin + # Check that the compiler detects when `Threads.@threads` is used inside a model + + e1 = quote + @model function f1() + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e1) + + e2 = quote + @model function f2() + for j in 1:10 + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e2) + + e3 = quote + @model function f3() + begin + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e3) + end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f43ed45a4..1d609a013 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -51,21 +51,19 @@ using Mooncake: Mooncake end @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.LogDensityFunction(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end + N = 100 + model = setthreadsafe(threaded(zeros(N)), true) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end end @@ -125,34 +123,32 @@ end end @testset "LogDensityFunction: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(ldf, x)) - @test iszero(bench.allocs) - end + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity($ldf, $x)) + @test iszero(bench.allocs) end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..879e936d6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -5,13 +5,23 @@ @test threadsafe_vi.varinfo === vi @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 + @test length(threadsafe_vi.accs_by_thread) == Threads.maxthreadid() expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end + @testset "setthreadsafe" begin + @model f() = x ~ Normal() + model = f() + @test !DynamicPPL.requires_threadsafe(model) + model = setthreadsafe(model, true) + @test DynamicPPL.requires_threadsafe(model) + model = setthreadsafe(model, false) + @test !DynamicPPL.requires_threadsafe(model) + end + # TODO: Add more tests of the public API @testset "API" begin vi = VarInfo(gdemo_default) @@ -41,8 +51,6 @@ end @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - x = rand(10_000) @model function wthreads(x) @@ -52,63 +60,24 @@ x[i] ~ Normal(x[i - 1], 1) end end - model = wthreads(x) - - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- DynamicPPL.evaluate_threadsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo + model = setthreadsafe(wthreads(x), true) - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) + function correct_lp(x) + lp = logpdf(Normal(0, 1), x[1]) for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) + lp += logpdf(Normal(x[i - 1], 1), x[i]) end + return lp end - model = wothreads(x) vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end + _, vi = DynamicPPL.evaluate!!(model, vi) - println("Without `@threads`:") - println(" default:") - @time model(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo + # check that logp is correct + @test getlogjoint(vi) ≈ correct_lp(x) + # check that varinfo was wrapped during the model evaluation + @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From 54ae7e30df1cfaa1f69202c8637d58afce55134d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 18:58:53 +0000 Subject: [PATCH 041/148] Standardise `:lp` -> `:logjoint` (#1161) * Standardise `:lp` -> `:logjoint` * changelog * fix a test --- HISTORY.md | 4 ++++ src/chains.jl | 4 ++-- test/chains.jl | 8 ++++---- test/ext/DynamicPPLMCMCChainsExt.jl | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index c15b4136a..48f2efb0e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -65,6 +65,10 @@ Leaf contexts require no changes, apart from a removal of the `NodeTrait` functi `ConditionContext` and `PrefixContext` are no longer exported. You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. +#### ParamsWithStats + +In the 'stats' part of `DynamicPPL.ParamsWithStats`, the log-joint is now consistently represented with the key `logjoint` instead of `lp`. + #### Miscellaneous Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. 
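As a sketch of what the rename means for downstream code (assuming `ps` is a `DynamicPPL.ParamsWithStats` obtained from a model evaluation; the variable name is hypothetical):

```julia
# `ps.stats` is a NamedTuple; the log-joint entry is now keyed `logjoint`
ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood  # was `ps.stats.lp`
```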
diff --git a/src/chains.jl b/src/chains.jl index d01606c3d..8ce4979c6 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -57,7 +57,7 @@ function ParamsWithStats( ( logprior=DynamicPPL.getlogprior(varinfo), loglikelihood=DynamicPPL.getloglikelihood(varinfo), - lp=DynamicPPL.getlogjoint(varinfo), + logjoint=DynamicPPL.getlogjoint(varinfo), ), ) end @@ -158,7 +158,7 @@ function ParamsWithStats( ( logprior=DynamicPPL.getlogprior(vi), loglikelihood=DynamicPPL.getloglikelihood(vi), - lp=DynamicPPL.getlogjoint(vi), + logjoint=DynamicPPL.getlogjoint(vi), ), ) end diff --git a/test/chains.jl b/test/chains.jl index 12a9ece71..498e2e912 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -20,9 +20,9 @@ using Test @test length(ps.params) == 2 @test haskey(ps.stats, :logprior) @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) + @test haskey(ps.stats, :logjoint) @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood @test ps.params[@varname(y)] ≈ ps.params[@varname(x)] + 1 @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(y)]), z) @@ -34,9 +34,9 @@ using Test @test length(ps.params) == 1 @test haskey(ps.stats, :logprior) @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) + @test haskey(ps.stats, :logjoint) @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(x)] + 1), z) end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 6091492df..445270ef8 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -20,7 +20,7 @@ using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC @test size(c, 1) == 50 @test size(c, 3) == 3 @test Set(c.name_map.parameters) == Set([:x, :y]) - @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp]) + @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :logjoint]) @test logpdf.(Normal(), c[:x]) ≈ c[:logprior] @test c.info.varname_to_symbol[@varname(x)] == :x @test c.info.varname_to_symbol[@varname(y)] == :y From 384e3ac398784af123f6e3596d83aab61a754a96 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 17:36:55 +0000 Subject: [PATCH 042/148] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/varnamedtuple.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d3f2ba13a..9efc1efa2 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -132,7 +132,7 @@ julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10) ``` -The optional keywoard argument `min_size` can be used to specify the minimum initial size. +The optional keyword argument `min_size` can be used to specify the minimum initial size. This is purely a performance optimisation, to avoid resizing if the eventual size is known ahead of time. """ @@ -521,9 +521,9 @@ end """ varname_to_lens(name::VarName{S}) where {S} -Convert a `VarName` to an `Accessor` lens, wrapping the first symdol in a `PropertyLens`. 
+Convert a `VarName` to an `Accessor` lens, wrapping the first symbol in a `PropertyLens`.
 
-This is used to simplify method dispatch for `_getindx`, `_setindex!!`, and `_haskey`, by
+This is used to simplify method dispatch for `_getindex`, `_setindex!!`, and `_haskey`, by
 considering `VarName`s to just be a special case of lenses.
 """
 function _varname_to_lens(name::VarName{S}) where {S}

From 9d61a54f5dc86ae94b597e068fcf60b0910c0194 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Wed, 3 Dec 2025 17:41:46 +0000
Subject: [PATCH 043/148] Add a microoptimisation

---
 src/varnamedtuple.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 9efc1efa2..f9040e6ad 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -304,8 +304,8 @@ function _resize_partialarray!!(pa::PartialArray, inds)
     # may use a linear index that does not match between the old and the new arrays.
     @inbounds for i in CartesianIndices(pa.data)
         mask_val = pa.mask[i]
-        new_mask[i] = mask_val
         if mask_val
+            new_mask[i] = mask_val
             new_data[i] = pa.data[i]
         end
     end

From 8c8e39f98519b7765c1f1c1b3de9f11e675ea62b Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Wed, 3 Dec 2025 17:49:31 +0000
Subject: [PATCH 044/148] Improve docstrings

---
 src/varnamedtuple.jl | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index f9040e6ad..881fde767 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -474,14 +474,16 @@ is `merge`, which recursively merges two `VarNamedTuple`s.
 
 The there are two major limitations to indexing by VarNamedTuples:
 
-* `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined.
-* Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`.
+* `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of
+  `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined.
+* Any `VarName`s with `IndexLens`es must have a consistent number of indices. That is, one
+  cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`.
 
 `setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store
-heterogeneous data under different indices of the same symbol. That is, if one either
+heterogeneous data under different indices of the same symbol. That is, if either
 
-* sets `a[1]` and `a[2]` to be of different types, or
-* sets `a[1].b` and `a[2].c`, without setting `a[1].c`. or `a[2].b`,
+* one sets `a[1]` and `a[2]` to be of different types, or
+* if `a[1]` and `a[2]` both exist, one sets `a[1].b` without setting `a[2].b`,
 
 then getting values for `a[1]` or `a[2]` will not be type stable. 
From c818bf887ab474d15944394e62237f8be29b0f32 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 18:01:15 +0000 Subject: [PATCH 045/148] Simplify use of QuoteNodes --- src/varnamedtuple.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 881fde767..1ca75d343 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -564,11 +564,11 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) push!(exs, :(data = (;))) for name in all_names val_expr = if name in names1 && name in names2 - :(_merge_recursive(vnt1.data[$(QuoteNode(name))], vnt2.data[$(QuoteNode(name))])) + :(_merge_recursive(vnt1.data.$name, vnt2.data.$name)) elseif name in names1 - :(vnt1.data[$(QuoteNode(name))]) + :(vnt1.data.$name) else - :(vnt2.data[$(QuoteNode(name))]) + :(vnt2.data.$name) end push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,))))) end From 3c02da40348be0ecdba27ec3bf13be6caec8375c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Dec 2025 16:16:36 +0000 Subject: [PATCH 046/148] Improve equality tests --- src/varnamedtuple.jl | 34 ++++++++++++++++++++- test/varnamedtuple.jl | 71 ++++++++++++++++++++++--------------------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 1ca75d343..db2462e53 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -208,13 +208,44 @@ function Base.:(==)(pa1::PartialArray, pa2::PartialArray) # TODO(mhauru) This could be optimised by not calling checkbounds on all elements # outside the size of an array, but not sure it's worth it. merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + result = true for i in CartesianIndices(merge_size) m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false if m1 != m2 return false end - if m1 && (pa1.data[i] != pa2.data[i]) + if m1 + elements_equal = pa1.data[i] == pa2.data[i] + if elements_equal === false + return false + elseif elements_equal === missing + # This branch can't short-circuit and just return missing, because some + # later values may be straight-up not equal. + result = missing + end + end + end + return result +end + +# Exactly as == above, except the comparison of the data elements uses isequal. +function Base.isequal(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + return false + end + size1 = _internal_size(pa1) + size2 = _internal_size(pa2) + # TODO(mhauru) This could be optimised by not calling checkbounds on all elements + # outside the size of an array, but not sure it's worth it. + merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + for i in CartesianIndices(merge_size) + m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false + m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false + if m1 != m2 + return false + end + if m1 && !isequal(pa1.data[i], pa2.data[i]) return false end end @@ -497,6 +528,7 @@ end VarNamedTuple(; kwargs...) 
= VarNamedTuple((; kwargs...)) Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +Base.isequal(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = isequal(vnt1.data, vnt2.data) Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) function Base.show(io::IO, vnt::VarNamedTuple) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 67f3d5c2b..3beadebf8 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,5 +1,6 @@ module VarNamedTupleTests +using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray @@ -20,11 +21,13 @@ function test_invariants(vnt::VarNamedTuple) v = getindex(vnt, k) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 + @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end # Check that the printed representation can be parsed back to an equal VarNamedTuple. vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 + @test isequal(vnt, vnt3) @test hash(vnt) == hash(vnt3) # Check that merge with an empty VarNamedTuple is a no-op. @test merge(vnt, VarNamedTuple()) == vnt @@ -218,40 +221,40 @@ end test_invariants(vnt) end - @testset "equality" begin - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, 1.0, @varname(a)) - @test vnt1 != vnt2 - - vnt2 = setindex!!(vnt2, 1.0, @varname(a)) - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - - # Try with index lenses too - vnt1 = setindex!!(vnt1, 2, @varname(c[2])) - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, 3, @varname(c[2])) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - - vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) - vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) - @test vnt1 != vnt2 + @testset "equality and hash" begin + # Test all combinations of having or not having the below values set, and having + # them set to any of the possible_values, and check that isequal and == return the + # expected value. + # NOTE: Be very careful adding new values to these sets. The below test has three + # nested loops over Combinatorics.combinations, the run time can explode very, very + # quickly. 
+ varnames = (@varname(b[1]), @varname(b[3]), @varname(c.d[2].e)) + possible_values = (missing, 1, -0.0, 0.0) + for vn_set in Combinatorics.combinations(varnames) + valuesets1 = Combinatorics.with_replacement_combinations( + possible_values, length(vn_set) + ) + valuesets2 = Combinatorics.with_replacement_combinations( + possible_values, length(vn_set) + ) + for vset1 in valuesets1, vset2 in valuesets2 + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_isequal = true + expected_doubleequal = true + for (vn, v1, v2) in zip(vn_set, vset1, vset2) + vnt1 = setindex!!(vnt1, v1, vn) + vnt2 = setindex!!(vnt2, v2, vn) + expected_isequal = expected_isequal & isequal(v1, v2) + expected_doubleequal = expected_doubleequal & (v1 == v2) + end + @test isequal(vnt1, vnt2) == expected_isequal + @test (vnt1 == vnt2) === expected_doubleequal + if expected_isequal + @test hash(vnt1) == hash(vnt2) + end + end + end end @testset "merge" begin From 35c3e20b1247493f74b38f582c170a4ab4957dca Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Dec 2025 19:00:30 +0000 Subject: [PATCH 047/148] ArrayLikeBlock WIP --- src/varnamedtuple.jl | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index db2462e53..0e984b7a7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -55,6 +55,11 @@ const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" const INDEX_TYPES = Union{Integer,UnitRange,Colon} +struct ArrayLikeBlock{T,I} + block::T + inds::I +end + """ PartialArray{ElType,numdims} @@ -105,6 +110,9 @@ means that the largest index set so far determines the memory usage of the `Part a few scattered values are set, a structure like `SparseArray` may be more appropriate. """ struct PartialArray{ElType,num_dims} + # TODO(mhauru) Consider trying FixedSizeArrays instead, see how it would change + # performance. We reallocate new Arrays every time when resizing anyway, except for + # Vectors, which can be extended in place. data::Array{ElType,num_dims} mask::Array{Bool,num_dims} @@ -395,7 +403,34 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else _resize_partialarray!!(pa, inds) end - new_data = setindex!!(pa.data, value, inds...) + + new_data = pa.data + if _is_multiindex(inds) && !(isa(value, AbstractArray)) + if !hasmethod(size, value) + throw(ArgumentError("Cannot assign a scalar value to a range.")) + end + if size(value) != map(x -> _length_needed(x), inds) + throw( + DimensionMismatch( + "Assigned value has size $(size(value)), which does not match the size " * + "implied by the indices $(map(x -> _length_needed(x), inds)).", + ), + ) + end + # At this point we know we have a value that is not an AbstractArray, but it has + # some notion of size, and that size matches the indices that are being set. In this + # case we wrap the value in a ArrayLikeBlock, and set all the individual indices + # point to that, with the right subindices. + first_index = first.(inds) + # Iterate over all the subindices of inds. + for ind in CartesianIndices(map(x -> _length_needed(x), inds)) + subinds = ntuple(i -> first_index[i] + ind[i] - 1, length(inds)) + new_data = _setindex!!(new_data, ArrayLikeBlock(value, Tuple(ind)), subinds...) + end + else + new_data = setindex!!(new_data, value, inds...) + end + if _is_multiindex(inds) pa.mask[inds...] 
.= true else From 4253e9b53c89aabe80f23fde3c4dfc059bf20d11 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 18:17:23 +0000 Subject: [PATCH 048/148] ArrayLikeBlock WIP2 --- src/varnamedtuple.jl | 118 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 50 ++++++++++++++++++ 2 files changed, 155 insertions(+), 13 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0e984b7a7..fa711e4f4 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -58,6 +58,13 @@ const INDEX_TYPES = Union{Integer,UnitRange,Colon} struct ArrayLikeBlock{T,I} block::T inds::I + + function ArrayLikeBlock(block::T, inds::I) where {T,I} + if !_is_multiindex(inds) + throw(ArgumentError("ArrayLikeBlock must be constructed with a multi-index")) + end + return new{T,I}(block, inds) + end end """ @@ -385,15 +392,102 @@ end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) _check_index_validity(pa, inds) - if !_haskey(pa, inds) + if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))) throw(BoundsError(pa, inds)) end - return getindex(pa.data, inds...) + val = getindex(pa.data, inds...) + + # If not for ArrayLikeBlocks, at this point we could just return val directly. However, + # we need to check if val contains any ArrayLikeBlocks, and if so, make sure that that + # we are retrieving exactly that block and nothing else. + + # The error we'll throw if the retrieval is invalid. + err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + if val isa ArrayLikeBlock + # Tried to get a single value, but it's an ArrayLikeBlock. + throw(err) + elseif val isa Array && (eltype(val) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(val)) + # Tried to get a range of values, and at least some of them may be ArrayLikeBlocks. + # The below isempty check is deliberately kept separate from the outer elseif, + # because the outer one can be resolved at compile time. + if isempty(val) + return val + end + first_elem = first(val) + if !(first_elem isa ArrayLikeBlock) + throw(err) + end + if inds != first_elem.inds + # The requested indices do not match the ones used to set the value. + throw(err) + end + # If _setindex!! works correctly, we should only be able to reach this point if all + # the elements in `val` are identical to first_elem. Thus we just return that one. + return first(val).block + else + return val + end end function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} _check_index_validity(pa, inds) - return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + hasall = + checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + + # If not for ArrayLikeBlocks, we could just return hasall directly. However, we need to + # check that if any ArrayLikeBlocks are included, they are fully included. + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return hasall + end + + if !hasall + return false + end + # From this point on we can assume that all the requested elements are set, and the only + # thing to check is that we are not partially indexing into any ArrayLikeBlocks. + # We've already checked checkbounds at the top of the function, and returned if it + # wasn't true, so @inbounds is safe. + subdata = @inbounds getindex(pa.data, inds...) 
+ if !_is_multiindex(inds) + return !(subdata isa ArrayLikeBlock) + end + return !any(elem -> elem isa ArrayLikeBlock && elem.inds != inds, subdata) +end + +function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if _is_multiindex(inds) + pa.mask[inds...] .= false + else + pa.mask[inds...] = false + end + return _concretise_eltype!!(pa) +end + +_ensure_range(r::UnitRange) = r +_ensure_range(i::Integer) = i:i + +function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return pa + end + + for i in CartesianIndices(map(_ensure_range, inds)) + if pa.mask[i] + val = @inbounds pa.data[i] + if val isa ArrayLikeBlock + pa = delete!!(pa, val.inds...) + end + end + end + return pa end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) @@ -403,13 +497,15 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else _resize_partialarray!!(pa, inds) end + pa = _remove_partial_blocks!!(pa, inds...) new_data = pa.data if _is_multiindex(inds) && !(isa(value, AbstractArray)) - if !hasmethod(size, value) + if !hasmethod(size, Tuple{typeof(value)}) throw(ArgumentError("Cannot assign a scalar value to a range.")) end - if size(value) != map(x -> _length_needed(x), inds) + inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) + if size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(size(value)), which does not match the size " * @@ -419,14 +515,10 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) end # At this point we know we have a value that is not an AbstractArray, but it has # some notion of size, and that size matches the indices that are being set. In this - # case we wrap the value in a ArrayLikeBlock, and set all the individual indices - # point to that, with the right subindices. - first_index = first.(inds) - # Iterate over all the subindices of inds. - for ind in CartesianIndices(map(x -> _length_needed(x), inds)) - subinds = ntuple(i -> first_index[i] + ind[i] - 1, length(inds)) - new_data = _setindex!!(new_data, ArrayLikeBlock(value, Tuple(ind)), subinds...) - end + # case we wrap the value in an ArrayLikeBlock, and set all the individual indices + # point to that. + alb = ArrayLikeBlock(value, inds) + new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) else new_data = setindex!!(new_data, value, inds...) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 3beadebf8..9365d5a7b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,6 +2,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset +using Distributions: Dirichlet using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray using AbstractPPL: VarName, prefix @@ -458,6 +459,55 @@ end VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0),),)),))""" end + + @testset "block variables" begin + # Tests for setting and getting block variables, i.e. variables that have a non-zero + # size in a PartialArray, but are not Arrays themselves. + expected_err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. 
+        """)
+        vnt = VarNamedTuple()
+        vnt = setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))
+        @test haskey(vnt, @varname(x[2:4]))
+        @test getindex(vnt, @varname(x[2:4])) == Dirichlet(3, 1.0)
+        @test !haskey(vnt, @varname(x[2:3]))
+        @test_throws expected_err getindex(vnt, @varname(x[2:3]))
+        @test !haskey(vnt, @varname(x[3]))
+        @test_throws expected_err getindex(vnt, @varname(x[3]))
+        @test !haskey(vnt, @varname(x[1]))
+        @test !haskey(vnt, @varname(x[5]))
+        vnt = setindex!!(vnt, 1.0, @varname(x[1]))
+        vnt = setindex!!(vnt, 1.0, @varname(x[5]))
+        @test haskey(vnt, @varname(x[1]))
+        @test haskey(vnt, @varname(x[5]))
+        @test_throws expected_err getindex(vnt, @varname(x[1:4]))
+        @test_throws expected_err getindex(vnt, @varname(x[2:5]))
+
+        # Setting any of these indices should remove the block variable x[2:4].
+        @testset "index = $index" for index in (2, 3, 4, 2:3, 3:5)
+            # Test setting different types of values.
+            vals = if index isa Int
+                (2.0,)
+            else
+                (fill(2.0, length(index)), Dirichlet(length(index), 2.0))
+            end
+            @testset "val = $val" for val in vals
+                vn = @varname(x[index])
+                vnt2 = copy(vnt)
+                vnt2 = setindex!!(vnt2, val, vn)
+                @test !haskey(vnt2, @varname(x[2:4]))
+                @test_throws BoundsError getindex(vnt2, @varname(x[2:4]))
+                other_index = index in (2, 2:3) ? 4 : 2
+                @test !haskey(vnt2, @varname(x[other_index]))
+                @test_throws BoundsError getindex(vnt2, @varname(x[other_index]))
+                @test haskey(vnt2, vn)
+                @test getindex(vnt2, vn) == val
+                @test haskey(vnt2, @varname(x[1]))
+                @test_throws BoundsError getindex(vnt2, @varname(x[1:4]))
+            end
+        end
+    end
 end
 end

From 5cb3916ddf1898b36a54e590e88ced043ee18765 Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Fri, 12 Dec 2025 18:50:24 +0000
Subject: [PATCH 049/148] Improve type stability of ArrayLikeBlock stuff

---
 src/varnamedtuple.jl  | 30 +++++++++++++++++++++++-----
 test/varnamedtuple.jl | 28 ++++++++++++++++++++++++--
 2 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index fa711e4f4..3210f474c 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -67,6 +67,8 @@ struct ArrayLikeBlock{T,I}
     end
 end
 
+_blocktype(::Type{<:ArrayLikeBlock{T}}) where {T} = T
+
 """
     PartialArray{ElType,numdims}
 
@@ -414,7 +416,13 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES})
         # The below isempty check is deliberately kept separate from the outer elseif,
         # because the outer one can be resolved at compile time.
         if isempty(val)
-            return val
+            # We need to return an empty array, but for type stability, we want to unwrap
+            # any ArrayLikeBlock types in the element type.
+            return if eltype(val) <: ArrayLikeBlock
+                Array{_blocktype(eltype(val)),ndims(val)}(undef, size(val))
+            else
+                val
+            end
         end
         first_elem = first(val)
         if !(first_elem isa ArrayLikeBlock)
@@ -490,6 +498,12 @@ function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
     return pa
 end
 
+function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
+    return _is_multiindex(inds) &&
+        !isa(value, AbstractArray) &&
+        hasmethod(size, Tuple{typeof(value)})
+end
+
 function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
     _check_index_validity(pa, inds)
     pa = if checkbounds(Bool, pa.mask, inds...)
@@ -500,7 +514,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
     pa = _remove_partial_blocks!!(pa, inds...)
     new_data = pa.data
 
-    if _is_multiindex(inds) && !(isa(value, AbstractArray))
+    if _needs_arraylikeblock(value, inds...)
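+        # From here on `value` is known to be a non-Array with a well-defined `size`,
+        # so it will be stored as a single ArrayLikeBlock spanning `inds`.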
if !hasmethod(size, Tuple{typeof(value)}) throw(ArgumentError("Cannot assign a scalar value to a range.")) end @@ -843,9 +857,15 @@ end function make_leaf(value, optic::IndexLens) inds = optic.indices num_inds = length(inds) - # Check if any of the indices are ranges or colons. If yes, value needs to be an - # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(inds) ? eltype(value) : typeof(value) + # The element type of the PartialArray depends on whether we are setting a single value + # or a range of values. + et = if !_is_multiindex(inds) + typeof(value) + elseif _needs_arraylikeblock(value, inds...) + ArrayLikeBlock{typeof(value),typeof(inds)} + else + eltype(value) + end pa = PartialArray{et,num_inds}() return _setindex!!(pa, value, optic) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 9365d5a7b..fe66ab317 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -468,9 +468,9 @@ end range of indices. """) vnt = VarNamedTuple() - vnt = setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4])) + vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) @test haskey(vnt, @varname(x[2:4])) - @test getindex(vnt, @varname(x[2:4])) == Dirichlet(3, 1.0) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) @test !haskey(vnt, @varname(x[2:3])) @test_throws expected_err getindex(vnt, @varname(x[2:3])) @test !haskey(vnt, @varname(x[3])) @@ -507,6 +507,30 @@ end @test_throws BoundsError getindex(vnt2, @varname(x[1:4])) end end + + # Extra checks, mostly for type stability and to confirm that multidimensional + # blocks work too. + struct TwoByTwoBlock end + Base.size(::TwoByTwoBlock) = (2, 2) + val = TwoByTwoBlock() + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + @test haskey(vnt, @varname(y.z[1:2, 1:2])) + @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val + @test !haskey(vnt, @varname(y.z[1, 1])) + @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test !haskey(vnt, @varname(y.z[1:2, 1:2])) + @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test haskey(vnt, @varname(y.z[4:5, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val end end From a96bb44a6cc5ca48dd293a8ba708f41cdce58300 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 09:07:39 +0000 Subject: [PATCH 050/148] Test more invariants --- test/varnamedtuple.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index fe66ab317..2f113aacc 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -469,6 +469,7 @@ end """) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) + test_invariants(vnt) @test haskey(vnt, @varname(x[2:4])) @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) @test !haskey(vnt, @varname(x[2:3])) @@ -479,6 +480,7 @@ end @test !haskey(vnt, @varname(x[5])) vnt = setindex!!(vnt, 1.0, @varname(x[1])) vnt = setindex!!(vnt, 1.0, @varname(x[5])) + test_invariants(vnt) @test haskey(vnt, @varname(x[1])) @test haskey(vnt, @varname(x[5])) @test_throws expected_err 
getindex(vnt, @varname(x[1:4])) @@ -496,6 +498,7 @@ end vn = @varname(x[index]) vnt2 = copy(vnt) vnt2 = setindex!!(vnt2, val, vn) + test_invariants(vnt) @test !haskey(vnt2, @varname(x[2:4])) @test_throws BoundsError getindex(vnt2, @varname(x[2:4])) other_index = index in (2, 2:3) ? 4 : 2 @@ -515,18 +518,21 @@ end val = TwoByTwoBlock() vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[1:2, 1:2])) @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val @test !haskey(vnt, @varname(y.z[1, 1])) @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[2:3, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test !haskey(vnt, @varname(y.z[1:2, 1:2])) @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[2:3, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test haskey(vnt, @varname(y.z[4:5, 2:3])) From a8014e6208f2cd346d400329fc2a71ae73f017c9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 10:21:21 +0000 Subject: [PATCH 051/148] Actually run VNT tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 9649aebbb..e0b42904c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ include("test_util.jl") include("utils.jl") include("accumulators.jl") include("compiler.jl") + include("varnamedtuple.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") From cfc60419fde7736ec4d051b7d99d61dbb5fb5a61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:13:05 +0000 Subject: [PATCH 052/148] Implement show for ArrayLikeBlock --- src/varnamedtuple.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 3210f474c..5dfbf153e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -67,6 +67,18 @@ struct ArrayLikeBlock{T,I} end end +function Base.show(io::IO, alb::ArrayLikeBlock) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + print(io, "ArrayLikeBlock(") + show(io, alb.block) + print(io, ", ") + show(io, alb.inds) + print(io, ")") + return nothing +end + _blocktype(::Type{ArrayLikeBlock{T}}) where {T} = T """ From e198fbb1c4234c0b4b69fe493751f7cba0b54a30 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:14:16 +0000 Subject: [PATCH 053/148] Change keys on VNT to return an array --- src/varnamedtuple.jl | 15 ++++----------- test/varnamedtuple.jl | 28 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 5dfbf153e..62c36d021 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -792,25 +792,18 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end -# TODO(mhauru) Should this return tuples, like it does now? That makes sense for -# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big. 
-# Also, this is not very type stable, it fails even in basic cases. A generated function -# would help, but I failed to make one. Might be something to do with a recursive -# generated function. function Base.keys(vnt::VarNamedTuple) - result = () + result = VarName[] for sym in keys(vnt.data) subdata = vnt.data[sym] if subdata isa VarNamedTuple subkeys = keys(subdata) - result = ( - result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... - ) + append!(result, [AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys]) elseif subdata isa PartialArray subkeys = keys(subdata) - result = (result..., (VarName{sym}(lens) for lens in subkeys)...) + append!(result, [VarName{sym}(lens) for lens in subkeys]) else - result = (result..., VarName{sym}()) + push!(result, VarName{sym}()) end end return result diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 2f113aacc..41bcd5fd5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -343,36 +343,36 @@ end @testset "keys" begin vnt = VarNamedTuple() - @test @inferred(keys(vnt)) == () + @test @inferred(keys(vnt)) == VarName[] vnt = setindex!!(vnt, 1.0, @varname(a)) # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. # We should improve type stability of keys(). - @test @inferred(keys(vnt)) == (@varname(a),) + @test @inferred(keys(vnt)) == [@varname(a)] vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, 15, @varname(b[2])) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, [10], @varname(c.x.y)) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)] vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])] vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), - ) + ] vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -382,10 +382,10 @@ end @varname(j[2]), @varname(j[3]), @varname(j[4]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(j[6])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -396,10 +396,10 @@ end @varname(j[3]), @varname(j[4]), @varname(j[6]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -411,7 +411,7 @@ end @varname(j[4]), @varname(j[6]), @varname(n[2].a), - ) + ] end @testset "printing" begin From b77b0af1d64ee60f5703de165e6f187d707f7550 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:15:54 +0000 Subject: [PATCH 054/148] Fix keys and some tests for PartialArray --- src/varnamedtuple.jl | 2 ++ test/varnamedtuple.jl | 25 +++++++++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 62c36d021..711fc6037 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -639,6 +639,8 @@ function Base.keys(pa::PartialArray) sublens = _varname_to_lens(vn) push!(ks, _compose_no_identity(sublens, lens)) end + 
elseif val isa ArrayLikeBlock + push!(ks, IndexLens(Tuple(val.inds))) else push!(ks, lens) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 41bcd5fd5..67c3621a7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,9 +2,8 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset -using Distributions: Dirichlet using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray +using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock using AbstractPPL: VarName, prefix using BangBang: setindex!! @@ -20,12 +19,18 @@ function test_invariants(vnt::VarNamedTuple) for k in keys(vnt) @test haskey(vnt, k) v = getindex(vnt, k) + # ArrayLikeBlocks are an implementation detail, and should not be exposed through + # getindex. + @test !(v isa ArrayLikeBlock) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end # Check that the printed representation can be parsed back to an equal VarNamedTuple. + # The below eval test is a bit fragile: If any elements in vnt don't respect the same + # reconstructability-from-repr property, this will fail. Likewise if any element uses + # in its repr print out types that are not in scope in this module, it will fail. vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 @test isequal(vnt, vnt3) @@ -461,6 +466,12 @@ end end @testset "block variables" begin + """ A type that has a size but is not an Array.""" + struct SizedThing + size::Tuple + end + Base.size(st::SizedThing) = st.size + # Tests for setting and getting block variables, i.e. variables that have a non-zero # size in a PartialArray, but are not Arrays themselves. expected_err = ArgumentError(""" @@ -468,10 +479,10 @@ end range of indices. """) vnt = VarNamedTuple() - vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), @varname(x[2:4]))) test_invariants(vnt) @test haskey(vnt, @varname(x[2:4])) - @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == SizedThing((3,)) @test !haskey(vnt, @varname(x[2:3])) @test_throws expected_err getindex(vnt, @varname(x[2:3])) @test !haskey(vnt, @varname(x[3])) @@ -492,7 +503,7 @@ end vals = if index isa Int (2.0,) else - (fill(2.0, length(index)), Dirichlet(length(index), 2.0)) + (fill(2.0, length(index)), SizedThing((length(index),))) end @testset "val = $val" for val in vals vn = @varname(x[index]) @@ -513,9 +524,7 @@ end # Extra checks, mostly for type stability and to confirm that multidimensional # blocks work too. - struct TwoByTwoBlock end - Base.size(::TwoByTwoBlock) = (2, 2) - val = TwoByTwoBlock() + val = SizedThing((2, 2)) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) test_invariants(vnt) From 633e920c561bb536a4b0c633c70f10a5ee5952a3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:39:09 +0000 Subject: [PATCH 055/148] Improve type stability --- src/varnamedtuple.jl | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 711fc6037..36ddd7377 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -27,9 +27,30 @@ function _setindex!!(arr::AbstractArray, value, optic::IndexLens) end # Some utilities for checking what sort of indices we are dealing with. 
-_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) -function _is_multiindex(::T) where {T<:Tuple} - return any(x <: UnitRange || x <: Colon for x in T.parameters) +# The non-generated function implementations of these would be +# _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) +# function _is_multiindex(::T) where {T<:Tuple} +# return any(x <: UnitRange || x <: Colon for x in T.parameters) +# end +# However, constant propagation sometimes fails if the index tuple is too big (e.g. length +# 4), so we play it safe and use generated functions. Constant propagating these is +# important, because many functions choose different paths based on their values, which +# would lead to type instability if they were only evaluated at runtime. +@generated function _has_colon(::T) where {T<:Tuple} + for x in T.parameters + if x <: Colon + return :(true) + end + end + return :(false) +end +@generated function _is_multiindex(::T) where {T<:Tuple} + for x in T.parameters + if x <: UnitRange || x <: Colon + return :(true) + end + end + return :(false) end """ From 222334a97219d6786cbac319285ebe2441d12230 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:39:56 +0000 Subject: [PATCH 056/148] Fix keys for PartialArray --- src/varnamedtuple.jl | 6 +++++- test/varnamedtuple.jl | 28 ++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 36ddd7377..ffc69bdcf 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -652,6 +652,7 @@ function Base.keys(pa::PartialArray) inds = findall(pa.mask) lenses = map(x -> IndexLens(Tuple(x)), inds) ks = Any[] + alb_inds_seen = Set{Tuple}() for lens in lenses val = getindex(pa.data, lens.indices...) if val isa VarNamedTuple @@ -661,7 +662,10 @@ function Base.keys(pa::PartialArray) push!(ks, _compose_no_identity(sublens, lens)) end elseif val isa ArrayLikeBlock - push!(ks, IndexLens(Tuple(val.inds))) + if !(val.inds in alb_inds_seen) + push!(ks, IndexLens(Tuple(val.inds))) + push!(alb_inds_seen, val.inds) + end else push!(ks, lens) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 67c3621a7..c9e9cb07f 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -40,6 +40,12 @@ function test_invariants(vnt::VarNamedTuple) @test merge(VarNamedTuple(), vnt) == vnt end +""" A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" +struct SizedThing{T<:Tuple} + size::T +end +Base.size(st::SizedThing) = st.size + @testset "VarNamedTuple" begin @testset "Construction" begin vnt1 = VarNamedTuple() @@ -417,6 +423,22 @@ end @varname(j[6]), @varname(n[2].a), ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + ] end @testset "printing" begin @@ -466,12 +488,6 @@ end end @testset "block variables" begin - """ A type that has a size but is not an Array.""" - struct SizedThing - size::Tuple - end - Base.size(st::SizedThing) = st.size - # Tests for setting and getting block variables, i.e. variables that have a non-zero # size in a PartialArray, but are not Arrays themselves. 
expected_err = ArgumentError(""" From d22face74a1d70bab93feeba2a1739d1b823a817 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:40:12 +0000 Subject: [PATCH 057/148] More ArrayLikeBlock tests --- test/varnamedtuple.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index c9e9cb07f..8be72a184 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -562,6 +562,30 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test haskey(vnt, @varname(y.z[4:5, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val + + # A lot like above, but with extra indices that are not ranges. + val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 1:2, 3, 1:2, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1, 3, 1, 4])) + @test_throws expected_err getindex(vnt, @varname(y.z[2, 1, 3, 1, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test_throws BoundsError getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[3, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end end From 4cb49e1194616ce8238d583c8aba22d18c1ab49f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:46:02 +0000 Subject: [PATCH 058/148] Add docstrings --- src/varnamedtuple.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index ffc69bdcf..ab66da5ac 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -76,6 +76,19 @@ const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" const INDEX_TYPES = Union{Integer,UnitRange,Colon} +""" + ArrayLikeBlock{T,I} + +A wrapper for non-array blocks stored in `PartialArray`s. + +When setting a value in a `PartialArray` over a range of indices, if the value being set +is not itself an `AbstractArray`, but has a well-defined size, we wrap it in an +`ArrayLikeBlock`, which records both the value and the indices it was set with. + +When getting values from a `PartialArray`, if any of the requested indices correspond to +an `ArrayLikeBlock`, we check that the requested indices match the ones used to set the +value. If they do, we return the underlying block, otherwise we throw an error. +""" struct ArrayLikeBlock{T,I} block::T inds::I @@ -136,6 +149,14 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem `ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly `numdims` elements. +One can set values in a `PartialArray` either element-by-element, or with ranges like +`arr[1:3,2] = [5,10,15]`. 
When setting values over a range of indices, the value being set
+must either be an `AbstractArray` or otherwise something for which `size(value)` is defined,
+and the size matches the range. If the value is an `AbstractArray`, the elements are copied
+individually, but if it is not, the value is stored as a block that takes up the whole
+range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done
+with the exact same range of indices, otherwise an error is thrown.
+
 If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will
 check if, after the new value has been set, the element type can be made more concrete. If
 so, a new `PartialArray` with a more concrete element type is returned. Thus the element
 type
From 420a6b2889429625a7fadb948937f2ef1ce1b6aa Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Mon, 15 Dec 2025 12:03:39 +0000
Subject: [PATCH 059/148] Remove redundant code, improve documentation

---
 src/varnamedtuple.jl | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index ab66da5ac..308951608 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -528,12 +528,20 @@ function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
     else
         pa.mask[inds...] = false
     end
-    return _concretise_eltype!!(pa)
+    return pa
 end
 
 _ensure_range(r::UnitRange) = r
 _ensure_range(i::Integer) = i:i
 
+"""
+    _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
+
+Remove any ArrayLikeBlocks that overlap with the given indices from the PartialArray.
+
+Note that this removes the whole block, even if only parts of it are within `inds`, to
+avoid partially indexing into ArrayLikeBlocks.
+"""
 function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
     et = eltype(pa)
     if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et)
@@ -552,6 +560,13 @@ function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
     return pa
 end
 
+"""
+    _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
+
+Check if the given value needs to be wrapped in an `ArrayLikeBlock` when being set at inds.
+
+The value only depends on the types of the arguments, and should be constant propagated.
+"""
 function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
     return _is_multiindex(inds) &&
         !isa(value, AbstractArray) &&
@@ -569,9 +584,6 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
     new_data = pa.data
 
     if _needs_arraylikeblock(value, inds...)
-        if !hasmethod(size, Tuple{typeof(value)})
-            throw(ArgumentError("Cannot assign a scalar value to a range."))
-        end
         inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds))
         if size(value) != inds_size
             throw(
@@ -584,7 +596,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
         # At this point we know we have a value that is not an AbstractArray, but it has
         # some notion of size, and that size matches the indices that are being set. In this
        # case we wrap the value in an ArrayLikeBlock, and set all the individual indices
-        # point to that.
+        # to point to that.
         alb = ArrayLikeBlock(value, inds)
         new_data = setindex!!(new_data, fill(alb, inds_size...), inds...)
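+        # Note that `fill` makes every index covered by `inds` refer to the *same*
+        # ArrayLikeBlock object; `_getindex` relies on all of these elements being
+        # identical when it retrieves the block.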
else From ce9da19422b4255c97bcb42c6e3a4a9eff29e31d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 12:10:11 +0000 Subject: [PATCH 060/148] Add Base.size(::RangeAndLinked) --- src/contexts/init.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 90394a24c..e666e0622 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -213,6 +213,8 @@ struct RangeAndLinked is_linked::Bool end +Base.size(ral::RangeAndLinked) = size(ral.range) + """ VectorWithRanges{Tlink}( varname_ranges::VarNamedTuple, From 4eb33e931853d55400cf6b897bf97f485d6016bd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 16:36:22 +0000 Subject: [PATCH 061/148] Fix issues with RangeAndLinked and VNT --- ext/DynamicPPLMarginalLogDensitiesExt.jl | 10 ++++------ src/contexts/init.jl | 13 +++++++++--- src/logdensityfunction.jl | 10 ++++++++-- src/varname.jl | 25 ++++++++++++++++++++++++ src/varnamedtuple.jl | 21 ++++++++++++++++---- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index ffb5baf25..e28560872 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -105,11 +105,9 @@ function DynamicPPL.marginalize( ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn - if DynamicPPL.getoptic(vn) === identity - ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range - else - ldf._varname_ranges[vn].range - end + # The type assertion helps in cases where the model is type unstable and thus + # `varname_ranges` may have an abstract element type. + (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( LogDensityFunctionWrapper(ldf, varinfo), diff --git a/src/contexts/init.jl b/src/contexts/init.jl index e666e0622..dc811df85 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -206,14 +206,16 @@ an unlinked value. $(TYPEDFIELDS) """ -struct RangeAndLinked +struct RangeAndLinked{T<:Tuple} # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} # whether it's linked is_linked::Bool + # original size of the variable before vectorisation + original_size::T end -Base.size(ral::RangeAndLinked) = size(ral.range) +Base.size(ral::RangeAndLinked) = ral.original_size """ VectorWithRanges{Tlink}( @@ -249,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) - return vr.varname_ranges[vn] + # The type assertion does nothing if VectorWithRanges has concrete element types, as is + # the case for all type stable models. However, if the model is not type stable, + # vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to + # assert that it is a RangeAndLinked, because even though it remains non-concrete, + # it'll allow the compiler to infer the types of `range` and `is_linked`. 
+    return vr.varname_ranges[vn]::RangeAndLinked
 end
 function init(
     ::Random.AbstractRNG,
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index 47b49a277..89e2b5989 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
     for (vn, idx) in md.idcs
         is_linked = md.is_transformed[idx]
         range = md.ranges[idx] .+ (start_offset - 1)
-        all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
+        orig_size = varnamesize(vn)
+        all_ranges = BangBang.setindex!!(
+            all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
+        )
         offset += length(range)
     end
     return all_ranges, offset
@@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
     for (vn, idx) in vnv.varname_to_index
         is_linked = vnv.is_unconstrained[idx]
         range = vnv.ranges[idx] .+ (start_offset - 1)
-        all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
+        orig_size = varnamesize(vn)
+        all_ranges = BangBang.setindex!!(
+            all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
+        )
         offset += length(range)
     end
     return all_ranges, offset
diff --git a/src/varname.jl b/src/varname.jl
index 3eb1f2460..7ffe9cc08 100644
--- a/src/varname.jl
+++ b/src/varname.jl
@@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected.
 ) where {s,missings,_F,_a,_T}
     return s in missings
 end
+
+# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL.
+"""
+    varnamesize(vn::VarName)
+
+Return the size of the object referenced by this VarName.
+
+```jldoctest
+julia> varnamesize(@varname(a))
+()
+
+julia> varnamesize(@varname(b[1:3, 2]))
+(3,)
+
+julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1]))
+(4, 4)
+```
+"""
+function varnamesize(vn::VarName)
+    l = AbstractPPL._last(vn.optic)
+    if l isa Accessors.IndexLens
+        return reduce((x, y) -> tuple(x..., y...), map(size, l.indices))
+    else
+        return ()
+    end
+end
diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 308951608..1340846a9 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -352,7 +352,13 @@ function _concretise_eltype!!(pa::PartialArray)
     if isconcretetype(eltype(pa))
         return pa
     end
-    new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...)
+    # We could use promote_type here, instead of typejoin. However, that would e.g.
+    # cause Ints to be converted to Float64s, since
+    # promote_type(Int, Float64) == Float64, which can cause problems. See
+    # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
+    # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing
+    # and Missing, rather than falling back on Any. However, it's not exported.
+    new_et = typejoin((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...)
     # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)?
     # In other words, does it help to be more concrete, even if we aren't fully concrete?
if new_et === eltype(pa)
@@ -588,8 +594,8 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
         if size(value) != inds_size
             throw(
                 DimensionMismatch(
-                    "Assigned value has size $(size(value)), which does not match the size " *
-                    "implied by the indices $(map(x -> _length_needed(x), inds)).",
+                    "Assigned value has size $(size(value)), which does not match the " *
+                    "size implied by the indices $(map(x -> _length_needed(x), inds)).",
                 ),
             )
         end
@@ -656,7 +668,14 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray)
         result
     else
         # Neither is strictly bigger than the other.
-        et = promote_type(eltype(pa1), eltype(pa2))
+        # We could use promote_type here, instead of typejoin. However, that would e.g.
+        # cause Ints to be converted to Float64s, since
+        # promote_type(Int, Float64) == Float64, which can cause problems. See
+        # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
+        # Base.promote_typejoin would be like typejoin, but creates Unions out of
+        # Nothing and Missing, rather than falling back on Any. However, it's not
+        # exported.
+        et = typejoin(eltype(pa1), eltype(pa2))
         new_data = Array{et,num_dims}(undef, merge_size)
         new_mask = fill(false, merge_size)
         result = PartialArray(new_data, new_mask)
From 633e920c561bb536a4b0c633c70f10a5ee5952a3 Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Mon, 15 Dec 2025 17:33:16 +0000
Subject: [PATCH 062/148] Write more design doc for ArrayLikeBlocks

---
 docs/src/internals/varnamedtuple.md | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md
index 47ff9c65e..63f4bb5b9 100644
--- a/docs/src/internals/varnamedtuple.md
+++ b/docs/src/internals/varnamedtuple.md
@@ -144,6 +144,29 @@ You can also set the elements with `vnt = setindex!!(vnt, 3.0, @varname(a[1]))`,
 At this point you can not set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, 5.0, @varname(a[5]))`.
 The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is.
 
+## Non-Array blocks with `IndexLens`es
+
+The above is all that is needed for setting regular scalar values.
+However, in DynamicPPL we also have a particular need for something slightly odd:
+We sometimes need to do calls like `setindex!!(vnt, val, @varname(a[1:5]))` on a `val` that is _not_ an `AbstractArray`, or even iterable at all.
+Normally this would error: As a scalar value with size `()`, `val` is the wrong size to be set with `@varname(a[1:5])`, which clearly wants something with size `(5,)`.
+However, we want to allow this even if `val` is not an iterable, if it is some object for which `size` is well-defined, and `size(val) == (5,)`.
+In DynamicPPL this comes up when storing e.g. the priors of a model, where a random variable like `@varname(a[1:5])` may be associated with a prior that is a 5-dimensional distribution.
+
+Internally, a `PartialArray` is just a regular `Array` with a mask saying which elements have been set.
+Hence we can't store `val` directly in the same `PartialArray`:
+We need it to take up a sub-block of the array, in our example case a sub-block of length 5.
+To this end, internally, `PartialArray` uses a wrapper type called `ArrayLikeBlock`, that stores `val` together with the indices that are being used to set it.
+The `PartialArray` has all its corresponding elements, in our example elements 1, 2, 3, 4, and 5, point to the same wrapper object.
+
+While such blocks can be stored using a wrapper like this, some care must be taken in indexing into these blocks.
+For instance, after setting a block with `setindex!!(vnt, val, @varname(a[1:5]))`, we can't `getindex(vnt, @varname(a[1]))`, since we can't return "the first element of five in `val`", because `val` may not be indexable in any way.
+Similarly, if next we set `setindex!!(vnt, some_other_value, @varname(a[1]))`, that should invalidate/delete the elements `@varname(a[2:5])`, since the block only makes sense as a whole.
+Because of these reasons, setting and getting blocks of well-defined size like this is allowed with `VarNamedTuple`s, but _only by always using the full range_.
+For instance, if `setindex!!(vnt, val, @varname(a[1:5]))` has been set, then the only valid `getindex` key to access `val` is `@varname(a[1:5])`;
+not `@varname(a[1:10])`, nor `@varname(a[3])`, nor anything else that overlaps with `@varname(a[1:5])`.
+`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element.
+
 ## Limitations
 
 This design has several benefits, for performance and generality, but it also has limitations:
From 57fd11a30b4bf5b5b55500300d5af6506c7e31d5 Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Tue, 16 Dec 2025 15:14:22 +0000
Subject: [PATCH 063/148] Make VNT support concretized slices

---
 docs/src/internals/varnamedtuple.md |   1 +
 src/test_utils.jl                   |   1 +
 src/test_utils/models.jl            | 111 ++++++++++++++++++++++++++++
 src/varnamedtuple.jl                |  31 +++++---
 test/simple_varinfo.jl              |   8 ++
 test/varnamedtuple.jl               |  21 +++++-
 6 files changed, 162 insertions(+), 13 deletions(-)

diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md
index 63f4bb5b9..aa08c119d 100644
--- a/docs/src/internals/varnamedtuple.md
+++ b/docs/src/internals/varnamedtuple.md
@@ -50,6 +50,7 @@ The typical use of this structure in DynamicPPL is that the user may define valu
 This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`.
 A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be.
 If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question.
diff --git a/src/test_utils.jl b/src/test_utils.jl index f584055b3..ebb516844 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1,6 +1,7 @@ module TestUtils using AbstractMCMC +using AbstractPPL: AbstractPPL using DynamicPPL using LinearAlgebra using Distributions diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 84e1f10d8..dcc2d92a2 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -565,6 +565,71 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end +@model function demo_nested_colons( + x=(; data=[(; subdata=transpose([1.5 2.0;]))]), ::Type{TV}=Array{Float64} +) where {TV} + n = length(x.data[1].subdata) + d = n ÷ 2 + s = (; params=[(; subparams=TV(undef, (d, 1, 2)))]) + s.params[1].subparams[:, 1, :] ~ reshape( + product_distribution(fill(InverseGamma(2, 3), n)), d, 2 + ) + s_vec = vec(s.params[1].subparams) + # TODO(mhauru) The below element type concretisation is because of + # https://github.com/JuliaFolds2/BangBang.jl/issues/39 + # which causes, when this is evaluated with an untyped VarInfo, s_vec to be an + # Array{Any}. + s_vec = [x for x in s_vec] + m ~ MvNormal(zeros(n), Diagonal(s_vec)) + + x.data[1].subdata[:, 1] ~ MvNormal(m, Diagonal(s_vec)) + + return (; s=s, m=m, x=x) +end +function logprior_true(model::Model{typeof(demo_nested_colons)}, s, m) + n = length(model.args.x.data[1].subdata) + # TODO(mhauru) We need to enforce a convention on whether this function gets called + # with the parameters as the model returns them, or with the parameters "unpacked". + # Currently different tests do different things. + s_vec = if s isa NamedTuple + vec(s.params[1].subparams) + else + vec(s) + end + return loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) +end +function loglikelihood_true(model::Model{typeof(demo_nested_colons)}, s, m) + # TODO(mhauru) We need to enforce a convention on whether this function gets called + # with the parameters as the model returns them, or with the parameters "unpacked". + # Currently different tests do different things. + s_vec = if s isa NamedTuple + vec(s.params[1].subparams) + else + vec(s) + end + return loglikelihood(MvNormal(m, Diagonal(s_vec)), model.args.x.data[1].subdata) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_nested_colons)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s.params[1].subparams, m) +end +function varnames(::Model{typeof(demo_nested_colons)}) + return [ + @varname( + s.params[1].subparams[ + AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(1))), + 1, + AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))), + ] + ), + # @varname(s.params[1].subparams[1,1,1]), + # @varname(s.params[1].subparams[1,1,2]), + @varname(m), + ] +end + const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, Model{typeof(demo_assume_dot_observe_literal)}, @@ -701,6 +766,51 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemo return vals end +function posterior_mean(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. + vals = rand_prior_true(model) + + vals.s.params[1].subparams[1, 1, 1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s.params[1].subparams[1, 1, 2] = 8 / 3 + vals.m[2] = 1 + + return vals +end +function likelihood_optima(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. 
+ vals = rand_prior_true(model) + + # NOTE: These are "as close to zero as we can get". + vals.s.params[1].subparams[1, 1, 1] = 1e-32 + vals.s.params[1].subparams[1, 1, 2] = 1e-32 + + vals.m[1] = 1.5 + vals.m[2] = 2.0 + + return vals +end +function posterior_optima(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. + vals = rand_prior_true(model) + + # TODO: Figure out exact for `s[1]`. + vals.s.params[1].subparams[1, 1, 1] = 0.890625 + vals.s.params[1].subparams[1, 1, 2] = 1 + vals.m[1] = 3 / 4 + vals.m[2] = 1 + + return vals +end +function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_nested_colons)}) + svec = rand(rng, InverseGamma(2, 3), 2) + return (; + s=(; params=[(; subparams=reshape(svec, (1, 1, 2)))]), + m=rand(rng, MvNormal(zeros(2), Diagonal(svec))), + ) +end + """ A collection of models corresponding to the posterior distribution defined by the generative process @@ -749,6 +859,7 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_observe_matrix_index(), demo_assume_matrix_observe_matrix_index(), + demo_nested_colons(), ) """ diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 1340846a9..55f613e87 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -30,7 +30,7 @@ end # The non-generated function implementations of these would be # _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) # function _is_multiindex(::T) where {T<:Tuple} -# return any(x <: UnitRange || x <: Colon for x in T.parameters) +# return any(x <: AbstractUnitRange || x <: Colon for x in T.parameters) # end # However, constant propagation sometimes fails if the index tuple is too big (e.g. length # 4), so we play it safe and use generated functions. Constant propagating these is @@ -39,18 +39,18 @@ end @generated function _has_colon(::T) where {T<:Tuple} for x in T.parameters if x <: Colon - return :(true) + return :(return true) end end - return :(false) + return :(return false) end @generated function _is_multiindex(::T) where {T<:Tuple} for x in T.parameters - if x <: UnitRange || x <: Colon - return :(true) + if x <: AbstractUnitRange || x <: Colon || x <: AbstractPPL.ConcretizedSlice + return :(return true) end end - return :(false) + return :(return false) end """ @@ -74,7 +74,10 @@ _merge_recursive(_, x2) = x2 const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" -const INDEX_TYPES = Union{Integer,UnitRange,Colon} +const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.ConcretizedSlice} + +_unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range +_unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x """ ArrayLikeBlock{T,I} @@ -376,7 +379,7 @@ end """Return the length needed in a dimension given an index.""" _length_needed(i::Integer) = i -_length_needed(r::UnitRange) = last(r) +_length_needed(r::AbstractUnitRange) = last(r) """Take the minimum size that a dimension of a PartialArray needs to be, and return the size we choose it to be. This size will be the smallest possible power of @@ -447,12 +450,16 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh throw(BoundsError(pa, inds)) end if _has_colon(inds) - throw(ArgumentError("Indexing PartialArrays with Colon is not supported")) + msg = """ + Indexing PartialArrays with Colon is not supported. 
+ You may need to concretise the `VarName` first.""" + throw(ArgumentError(msg)) end return nothing end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))) throw(BoundsError(pa, inds)) @@ -501,6 +508,7 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) end function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) hasall = checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) @@ -528,6 +536,7 @@ function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} end function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if _is_multiindex(inds) pa.mask[inds...] .= false @@ -537,7 +546,7 @@ function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) return pa end -_ensure_range(r::UnitRange) = r +_ensure_range(r::AbstractUnitRange) = r _ensure_range(i::Integer) = i:i """ @@ -580,6 +589,7 @@ function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) pa = if checkbounds(Bool, pa.mask, inds...) pa @@ -733,6 +743,7 @@ The there are two major limitations to indexing by VarNamedTuples: * `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined. + However, _concretised_ `VarName`s with `Colon`s are supported. * Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 42e377440..2c0e21bec 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -144,6 +144,14 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.ALL_MODELS + if model.f === DynamicPPL.TestUtils.demo_nested_colons + # TODO(mhauru) Either VarNamedVector or SimpleVarInfo has a bug that causes + # the push!! below to fail with a NamedTuple variable like what + # demo_nested_colons has. I don't want to fix it now though, because this may + # all go soon (as of 2025-12-16). + @test false broken = true + continue + end # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 8be72a184..6578d19ae 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -4,7 +4,7 @@ using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock -using AbstractPPL: VarName, prefix +using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!! 
""" @@ -231,6 +231,25 @@ Base.size(st::SizedThing) = st.size vnt = @inferred(setindex!!(vnt, 6, vn5)) @test @inferred(getindex(vnt, vn5)) == 6 test_invariants(vnt) + + # ConcretizedSlices + vnt = VarNamedTuple() + x = [1, 2, 3] + vn = concretize(@varname(y[:]), x) + vnt = @inferred(setindex!!(vnt, x, vn)) + @test haskey(vnt, vn) + @test @inferred(getindex(vnt, vn)) == x + test_invariants(vnt) + + y = fill("a", (3, 2, 4)) + x = y[:, 2, :] + a = (; b=[nothing, nothing, (; c=(; d=reshape(y, (1, 3, 2, 4, 1))))]) + vn = @varname(a.b[3].c.d[1, 3:5, 2, :, 1]) + vn = concretize(vn, a) + vnt = @inferred(setindex!!(vnt, x, vn)) + @test haskey(vnt, vn) + @test @inferred(getindex(vnt, vn)) == x + test_invariants(vnt) end @testset "equality and hash" begin From 7bdce5ce53a9dc410ad140f782d4ad4c7b31f939 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 16:47:50 +0000 Subject: [PATCH 064/148] Start the VNT HISTORY.md entry --- HISTORY.md | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 9dc4414ce..0ad1824dd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,112 @@ ## 0.40 +### Changes to indexing random variables with square brackets + +0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. ones that use indexing with square brackets. +Most of this is invisible to users, but it has some effects that show on the surface. +The gist of the changes is that any indexing by square brackets is now implicitly assumed to be indexing into a regular `Base.Array`, with 1-based indexing. +The general effect this has is that the new rules on what is and isn't allowed are stricter, forbidding some old syntax that used to be allowed, and at the same time guaranteeing that it works correctly. +(Previously there were some sharp edges around these sorts of variable names.) + +#### No more linear indexing of multidimensional arrays + +Previously you could do this: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1] ~ Normal() +x[1, 1] ~ Normal() +``` + +Now you can't, this will error. +If you first create a variable like `x[1]`, DynamicPPL from there on assumes that this variable only takes a single index (like a `Vector`). +It will then error if you try to index the same variable with any other number of indices. + +The same logic also bans this, which likewise was previously allowed: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1, 1, 1] ~ Normal() +x[1, 1] ~ Normal() +``` + +This made use of Julia allowing trailing indices of `1`. + +Note that the above models were previously quite dangerous and easy to misuse, because DynamicPPL was oblivious to the fact that e.g. `x[1]` and `x[1,1]` refer to the same element. +Both of the above examples previously created 2-dimensional models, with two distinct random variables, one of which effectively overwrote the other in the model body. + +TODO(mhauru) This may cause surprising issues when using `eachindex`, which is generally encouraged, e.g. + +``` +x = Array{Float64,2}(undef, (3, 3) +for i in eachindex(x) + x[i] ~ Normal() +end +``` + +Maybe we should fix linear indexing before releasing? + +#### No more square bracket indexing with arbitrary keys + +Previously you could do this: + +```julia +x = Dict() +x["a"] ~ Normal() +``` + +Now you can't, this will error. 
+This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `"a"` is not a valid index.
+You can still use a dictionary on the left-hand side of a `~` statement as long as the indices are valid indices to an `Array`, e.g. integers.
+
+#### No more unusually indexed arrays, such as `OffsetArrays`
+
+Previously you could do this:
+
+```julia
+using OffsetArrays
+x = OffsetArray(Vector{Float64}(undef, 3), -3)
+x[-2] ~ Normal()
+0.0 ~ Normal(x[-2])
+```
+
+Now you can't, this will error.
+This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `-2` is not a valid index.
+
+#### The above limitations are not fundamental
+
+The above new restrictions to what sort of variable names are allowed aren't fundamental.
+With some effort we could e.g. add support for linear indexing, this time done properly, so that e.g. `x[1,1]` and `x[1]` would be the same variable.
+Likewise, we could manually add structures to support indexing into dictionaries or `OffsetArrays`.
+If this would be useful to you, let us know.
+
+#### This only affects `~` statements
+
+You can still use any arbitrary indexing within your model in statements that don't involve `~`.
+For instance, you can use `OffsetArray`s, or linear indexing, as long as you don't put them on the left-hand side of a `~`.
+
+#### Performance benefits
+
+The upside of all these new limitations is that models that use square bracket indexing are now faster.
+For instance, take the following model:
+
+```julia
+@model function f()
+    x = Vector{Float64}(undef, 1000)
+    for i in eachindex(x)
+        x[i] ~ Normal()
+    end
+    return 0.0 ~ Normal(sum(x))
+end
+```
+
+Evaluating the log joint for this model has gotten about 3 times faster in v0.40.
+
+#### Robustness benefits
+
+TODO(mhauru) Add an example here for how this improves `condition`ing, once `condition` uses `VarNamedTuple`.
 
 ## 0.39.4
 
 Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).
From a8014e6208f2cd346d400329fc2a71ae73f017c9 Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Tue, 16 Dec 2025 17:48:03 +0000
Subject: [PATCH 065/148] Skip a type stability test on 1.10

---
 test/model.jl | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/test/model.jl b/test/model.jl
index c878fd905..30fc614ca 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -408,6 +408,14 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2)
         ]
         @testset "$(model.f)" for model in models_to_test
+            if model.f === DynamicPPL.TestUtils.demo_nested_colons && VERSION < v"1.11"
+                # On v1.10, the demo_nested_colons model, which uses a lot of
+                # NamedTuples, is badly type unstable. Not worth doing much about
+                # it, since it's fixed on later Julia versions, so just skipping
+                # these tests.
+ @test_skip false skip = true + continue + end vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( From 753ca81b85af88adb0970dff88670dda2445fa4d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 18:32:34 +0000 Subject: [PATCH 066/148] Fix test_skip --- test/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model.jl b/test/model.jl index 30fc614ca..3272fd8b5 100644 --- a/test/model.jl +++ b/test/model.jl @@ -413,7 +413,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # NamedTuples, is badly type unstable. Not worth doing much about # it, since it's fixed on later Julia versions, so just skipping # these tests. - @test_skip false skip = true + @test false skip = true continue end vns = DynamicPPL.TestUtils.varnames(model) From 34c42af9ca17c4d71df8e85888089584cbf8592c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 1 Dec 2025 18:02:14 +0000 Subject: [PATCH 067/148] Change Base.keys on PartialArrays to be more type stable --- src/varnamedtuple.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 55f613e87..0a23afa66 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -715,7 +715,7 @@ function Base.keys(pa::PartialArray) subkeys = keys(val) for vn in subkeys sublens = _varname_to_lens(vn) - push!(ks, _compose_no_identity(sublens, lens)) + ks = push!!(ks, _compose_no_identity(sublens, lens)) end elseif val isa ArrayLikeBlock if !(val.inds in alb_inds_seen) @@ -723,7 +723,7 @@ function Base.keys(pa::PartialArray) push!(alb_inds_seen, val.inds) end else - push!(ks, lens) + ks = push!!(ks, lens) end end return ks From 262e303b7977bf614a28c88cfd354ec67fde04a2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 1 Dec 2025 18:04:06 +0000 Subject: [PATCH 068/148] Implement Base.values for VNT --- src/varnamedtuple.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0a23afa66..22c2e1dfc 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -729,6 +729,21 @@ function Base.keys(pa::PartialArray) return ks end +function Base.values(pa::PartialArray) + inds = findall(pa.mask) + vs = Union{}[] + for ind in inds + val = getindex(pa.data, ind...) + if val isa VarNamedTuple + subvalues = values(val) + vs = push!!(vs, subvalues...) + else + vs = push!!(vs, val) + end + end + return vs +end + """ VarNamedTuple{names,Values} @@ -893,6 +908,24 @@ function Base.keys(vnt::VarNamedTuple) return result end +# TODO(mhauru) Same comments as for keys. +function Base.values(vnt::VarNamedTuple) + result = () + for sym in keys(vnt.data) + subdata = vnt.data[sym] + if subdata isa VarNamedTuple + subvalues = values(subdata) + result = (result..., subvalues...) + elseif subdata isa PartialArray + subvalues = values(subdata) + result = (result..., subvalues...) + else + result = (result..., subdata) + end + end + return result +end + # The following methods, indexing with ComposedFunction, are exactly the same for # VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and # inner lenses. 
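To make the intent of the two patches above concrete, here is a minimal sketch of how `keys` and `values` are meant to pair up on a `VarNamedTuple`. This is the same round-trip invariant that `test_invariants` in the test suite encodes later in this series; the concrete element types of the returned collections are an implementation detail, and this sketch only uses API already exercised by the tests (`setindex!!`, `@varname`, `keys`, `values`, `==`):

```julia
using DynamicPPL: VarNamedTuple, @varname
using BangBang: setindex!!

vnt = VarNamedTuple()
vnt = setindex!!(vnt, 1.0, @varname(a))
vnt = setindex!!(vnt, [1, 2, 3], @varname(b))

# keys(vnt) and values(vnt) iterate in the same order, so zipping them
# reconstructs an equal VarNamedTuple.
rebuilt = let r = VarNamedTuple()
    for (k, v) in zip(keys(vnt), values(vnt))
        r = setindex!!(r, v, k)
    end
    r
end
@assert rebuilt == vnt
```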
From 1f2cd8bf9f9a3b2fdce96cd52f0ddff1e6f08acc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 1 Dec 2025 18:04:25 +0000 Subject: [PATCH 069/148] Implement isempty and empty for VNT --- src/varnamedtuple.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 22c2e1dfc..d5faf3ade 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -804,6 +804,13 @@ function Base.copy(vnt::VarNamedTuple{names}) where {names} ) end +# TODO(mhauru) Should this recur to PartialArray? +Base.isempty(vnt::VarNamedTuple) = isempty(vnt.data) + +# TODO(mhauru) Should this in fact keep the PartialArrays in place, but set them all to have +# mask = fill(false, size(pa.mask))? That might save some allocations. +Base.empty(::VarNamedTuple) = VarNamedTuple() + """ varname_to_lens(name::VarName{S}) where {S} From 1debab196dc4bf9258c1dc2bbffc816f8105d1e8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 1 Dec 2025 18:05:04 +0000 Subject: [PATCH 070/148] Use VNT in VAIMAcc --- src/chains.jl | 2 +- src/compiler.jl | 4 ++-- src/values_as_in_model.jl | 15 +++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 319579a9c..71ca29a8f 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -5,7 +5,7 @@ A struct which contains parameter values extracted from a `VarInfo`, along with statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are optional. """ -struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple} +struct ParamsWithStats{P<:Union{OrderedDict{<:VarName,<:Any},VarNamedTuple},S<:NamedTuple} params::P stats::S end diff --git a/src/compiler.jl b/src/compiler.jl index 1b4260121..cd6cf29fd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -434,14 +434,14 @@ end function generate_assign(left, right) # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for - # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. + # ValuesAsInModel then in addition we push!! the pair of `x` and `y` to the accumulator. @gensym acc right_val vn return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) __varinfo__ = $(map_accumulator!!)( - $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + $acc -> push!!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) end $left = $right_val diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 71baebe92..d7d2da18a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -10,14 +10,14 @@ wants to extract the realization of a model in a constrained space. 
# Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelAccumulator <: AbstractAccumulator +struct ValuesAsInModelAccumulator{VNT<:VarNamedTuple} <: AbstractAccumulator "values that are extracted from the model" - values::OrderedDict{<:VarName} + values::VNT "whether to extract variables on the LHS of :=" include_colon_eq::Bool end function ValuesAsInModelAccumulator(include_colon_eq) - return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) + return ValuesAsInModelAccumulator(VarNamedTuple(), include_colon_eq) end function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) @@ -45,8 +45,9 @@ function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumula ) end -function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val) - setindex!(acc.values, deepcopy(val), vn) +function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val) + # TODO(mhauru) Why do we deepcopy here? + Accessors.@reset acc.values = setindex!!(acc.values, deepcopy(val), vn) return acc end @@ -56,7 +57,7 @@ function is_extracting_values(vi::AbstractVarInfo) end function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) - return push!(acc, vn, val) + return push!!(acc, vn, val) end accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc @@ -75,6 +76,8 @@ working in unconstrained space. Hence this method is a "safe" way of obtaining realizations in constrained space at the cost of additional model evaluations. +Returns a `VarNamedTuple`. + # Arguments - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. From 1db910c349295466dfed99ce4f432cd790e71944 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 15:56:46 +0000 Subject: [PATCH 071/148] WIP implementation of empty!! for VNT --- src/varnamedtuple.jl | 45 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d5faf3ade..6bfb6c98b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -337,6 +337,13 @@ function Base.hash(pa::PartialArray, h::UInt) return h end +function BangBang.empty!!(pa::PartialArray) + for i in eachindex(pa.mask) + @inbounds pa.mask[i] = false + end + return pa +end + """ _concretise_eltype!!(pa::PartialArray) @@ -804,12 +811,48 @@ function Base.copy(vnt::VarNamedTuple{names}) where {names} ) end +_has_partial_array(::Type{T}) where {T} = false +_has_partial_array(::Type{<:PartialArray}) = true + +@generated function _has_partial_array( + ::Type{VarNamedTuple{Names,Values}} +) where {Names,Values} + exs = Expr[] + for T in Values.parameters + if _has_partial_array(T) + push!(exs, :(return true)) + end + end + push!(exs, :(return false)) + return Expr(:block, exs...) +end + # TODO(mhauru) Should this recur to PartialArray? Base.isempty(vnt::VarNamedTuple) = isempty(vnt.data) # TODO(mhauru) Should this in fact keep the PartialArrays in place, but set them all to have # mask = fill(false, size(pa.mask))? That might save some allocations. 
-Base.empty(::VarNamedTuple) = VarNamedTuple() +Base.empty(vnt::VarNamedTuple) = VarNamedTuple() + +@generated function BangBang.empty!!(vnt::VarNamedTuple{Names,Values}) where {Names,Values} + if !_has_partial_array(VarNamedTuple{Names,Values}) + return :(return VarNamedTuple()) + end + new_names = () + new_values = () + for (name, ValType) in zip(Names, Values.parameters) + if _has_partial_array(ValType) + new_values = (new_values..., :(BangBang.empty!!(vnt.data[$(QuoteNode(name))]))) + new_names = (new_names..., name) + end + end + if length(new_names) != length(new_values) + error(new_values) + end + return quote + return VarNamedTuple(NamedTuple{$new_names}(($(new_values...),))) + end +end """ varname_to_lens(name::VarName{S}) where {S} From b02d014e9b4478b92711409e36951183608ac6d0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 14:30:06 +0000 Subject: [PATCH 072/148] Make to_samples use VNT and fix related tests --- ext/DynamicPPLMCMCChainsExt.jl | 11 ++++++----- src/contexts/init.jl | 2 +- src/test_utils/models.jl | 32 ++++++++++++++++++++++++++++++-- test/chains.jl | 3 ++- test/model.jl | 12 +++++++++++- 5 files changed, 50 insertions(+), 10 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 8ad828648..07324d665 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,6 +1,7 @@ module DynamicPPLMCMCChainsExt using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random +using BangBang: setindex!! using MCMCChains: MCMCChains function getindex_varname( @@ -95,11 +96,11 @@ function AbstractMCMC.to_samples( idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) # Get parameters params_matrix = map(idxs) do (sample_idx, chain_idx) - d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() + vnt = DynamicPPL.VarNamedTuple() for vn in get_varnames(chain) - d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) + vnt = setindex!!(vnt, getindex_varname(chain, sample_idx, vn, chain_idx), vn) end - d + vnt end # Statistics stats_matrix = if :internals in MCMCChains.sections(chain) @@ -164,8 +165,8 @@ end fallback=nothing, ) -Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`, -returning an matrix of `(retval, updated_at)` tuples. +Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`, +returning a matrix of `(retval, updated_at)` tuples. This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the initialisation strategy when re-evaluating the model. For many usecases the fallback should diff --git a/src/contexts/init.jl b/src/contexts/init.jl index dc811df85..dd9e99421 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -169,7 +169,7 @@ InitFromParams(params) = InitFromParams(params, InitFromPrior()) function init( rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} -) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple,VarNamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. 
This is actually quite non-trivial though because diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index dcc2d92a2..848ed35d4 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -7,6 +7,26 @@ # # Some additionally contain an implementation of `rand_prior_true`. +""" + varname(model::Model) + +Return the VarNames defined in `model`, as a Vector. +""" +function varnames end + +# TODO(mhauru) The fact that the below function exists is a sign that we are inconsistent in +# how we handle IndexLenses. This should hopefully be resolved once we consistently use +# VarNamedTuple rather than dictionaries everywhere. +""" + varnames_split(model::Model) + +Return the VarNames in `model`, with any ranges or colons split into individual indices. + +The default implementation is to just return `varname(model)`. If something else is needed, +this should be defined separately. +""" +varnames_split(model::Model) = varnames(model) + """ demo_dynamic_constraint() @@ -77,6 +97,9 @@ end function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)}) return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])] end +function varnames_split(model::Model{typeof(demo_one_variable_multiple_constraints)}) + return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4]), @varname(x[5])] +end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_one_variable_multiple_constraints)}, x ) @@ -624,8 +647,13 @@ function varnames(::Model{typeof(demo_nested_colons)}) AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))), ] ), - # @varname(s.params[1].subparams[1,1,1]), - # @varname(s.params[1].subparams[1,1,2]), + @varname(m), + ] +end +function varnames_split(::Model{typeof(demo_nested_colons)}) + return [ + @varname(s.params[1].subparams[1, 1, 1]), + @varname(s.params[1].subparams[1, 1, 2]), @varname(m), ] end diff --git a/test/chains.jl b/test/chains.jl index 36c274b62..608a9a9cf 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -82,7 +82,8 @@ end ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected - @test length(ps.params) == length(keys(vi)) + expected_length = sum(prod ∘ DynamicPPL.varnamesize, keys(vi)) + @test length(ps.params) == expected_length # Iterate over all variables to check that their values match for vn in keys(vi) diff --git a/test/model.jl b/test/model.jl index 3272fd8b5..29b9650a5 100644 --- a/test/model.jl +++ b/test/model.jl @@ -58,6 +58,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() #### logprior, logjoint, loglikelihood for MCMC chains #### @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + if model.f === DynamicPPL.TestUtils.demo_nested_colons + # TODO(mhauru) The below test fails on this model, due to the VarName + # s.params[1].subparams[:, 1, :], which AbstractPPL.varname_leaves splits + # into subvarnames like s.params[1].subparams[:, 1, :][1, 1], but the chain + # would know as s.params[1].subparams[1, 1, 1]. Unsure what the correct fix + # is, so leaving this for later. 
+ @test false broken = true + continue + end N = 200 chain = make_chain_from_prior(model, N) logpriors = logprior(model, chain) @@ -441,6 +450,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "values_as_in_model" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) + vns_split = DynamicPPL.TestUtils.varnames_split(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @@ -450,7 +460,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() realizations = values_as_in_model(model, false, varinfo) # Ensure that all variables are found. vns_found = collect(keys(realizations)) - @test vns ∩ vns_found == vns ∪ vns_found + @test vns_split ∩ vns_found == vns_split ∪ vns_found # Ensure that the values are the same. for vn in vns @test realizations[vn] == varinfo[vn] From ea807fc1fff13543738068b0867c1a25768cffa2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 14:32:16 +0000 Subject: [PATCH 073/148] Add hasvalue and getvalue to VNT --- src/varnamedtuple.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 6bfb6c98b..ab0f8e8a1 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -2,6 +2,8 @@ module VarNamedTuples using AbstractPPL +using AbstractPPL: AbstractPPL +using Distributions: Distribution using BangBang using Accessors using ..DynamicPPL: _compose_no_identity @@ -1042,4 +1044,17 @@ function make_leaf(value, optic::IndexLens) return _setindex!!(pa, value, optic) end +function to_dict(::Type{T}, vnt::VarNamedTuple) where {T<:AbstractDict{<:VarName}} + pairs = splat(Pair).(zip(keys(vnt), values(vnt))) + return T(pairs...) +end + +function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, ::Distribution) + return haskey(vnt, vn) +end + +function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, ::Distribution) + return getindex(vnt, vn) +end + end From f2d0c33eabfc5842683fd10335af1c31e96f48a5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 14:35:41 +0000 Subject: [PATCH 074/148] Improve keys and values for VNT --- src/varnamedtuple.jl | 50 +++++++++++++++++++++++++------------------ test/varnamedtuple.jl | 22 +++++++++++++++++-- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index ab0f8e8a1..e2db48a41 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -714,11 +714,16 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) end function Base.keys(pa::PartialArray) - inds = findall(pa.mask) - lenses = map(x -> IndexLens(Tuple(x)), inds) + # TODO(mhauru) Should this rather be Union{}[]? It would make this very type unstable + # and cause more allocations, but would result in more concrete element types. Same + # question for Base.keys on VNT and Base.values. ks = Any[] alb_inds_seen = Set{Tuple}() - for lens in lenses + for ind in CartesianIndices(pa.mask) + @inbounds if !pa.mask[ind] + continue + end + lens = IndexLens(Tuple(ind)) val = getindex(pa.data, lens.indices...) 
if val isa VarNamedTuple subkeys = keys(val) @@ -728,7 +733,7 @@ function Base.keys(pa::PartialArray) end elseif val isa ArrayLikeBlock if !(val.inds in alb_inds_seen) - push!(ks, IndexLens(Tuple(val.inds))) + ks = push!!(ks, IndexLens(Tuple(val.inds))) push!(alb_inds_seen, val.inds) end else @@ -739,13 +744,21 @@ function Base.keys(pa::PartialArray) end function Base.values(pa::PartialArray) - inds = findall(pa.mask) - vs = Union{}[] - for ind in inds - val = getindex(pa.data, ind...) + vs = Any[] + albs_seen = Set{ArrayLikeBlock}() + for ind in CartesianIndices(pa.mask) + @inbounds if !pa.mask[ind] + continue + end + val = getindex(pa.data, ind) if val isa VarNamedTuple subvalues = values(val) vs = push!!(vs, subvalues...) + elseif val isa ArrayLikeBlock + if !(val in albs_seen) + vs = push!!(vs, val.block) + push!(albs_seen, val) + end else vs = push!!(vs, val) end @@ -819,14 +832,12 @@ _has_partial_array(::Type{<:PartialArray}) = true @generated function _has_partial_array( ::Type{VarNamedTuple{Names,Values}} ) where {Names,Values} - exs = Expr[] for T in Values.parameters if _has_partial_array(T) - push!(exs, :(return true)) + return :(return true) end end - push!(exs, :(return false)) - return Expr(:block, exs...) + return :(return false) end # TODO(mhauru) Should this recur to PartialArray? @@ -834,7 +845,7 @@ Base.isempty(vnt::VarNamedTuple) = isempty(vnt.data) # TODO(mhauru) Should this in fact keep the PartialArrays in place, but set them all to have # mask = fill(false, size(pa.mask))? That might save some allocations. -Base.empty(vnt::VarNamedTuple) = VarNamedTuple() +Base.empty(::VarNamedTuple) = VarNamedTuple() @generated function BangBang.empty!!(vnt::VarNamedTuple{Names,Values}) where {Names,Values} if !_has_partial_array(VarNamedTuple{Names,Values}) @@ -848,9 +859,6 @@ Base.empty(vnt::VarNamedTuple) = VarNamedTuple() new_names = (new_names..., name) end end - if length(new_names) != length(new_values) - error(new_values) - end return quote return VarNamedTuple(NamedTuple{$new_names}(($(new_values...),))) end @@ -960,19 +968,19 @@ function Base.keys(vnt::VarNamedTuple) return result end -# TODO(mhauru) Same comments as for keys. function Base.values(vnt::VarNamedTuple) - result = () + # TODO(mhauru) Same comments as for keys for type stability and Any vs Union{} + result = Any[] for sym in keys(vnt.data) subdata = vnt.data[sym] if subdata isa VarNamedTuple subvalues = values(subdata) - result = (result..., subvalues...) + append!(result, subvalues) elseif subdata isa PartialArray subvalues = values(subdata) - result = (result..., subvalues...) + append!(result, subvalues) else - result = (result..., subdata) + push!(result, subdata) end end return result diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 6578d19ae..3d8223a2a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -38,6 +38,12 @@ function test_invariants(vnt::VarNamedTuple) # Check that merge with an empty VarNamedTuple is a no-op. @test merge(vnt, VarNamedTuple()) == vnt @test merge(VarNamedTuple(), vnt) == vnt + # Check that the VNT can be constructed back from its keys and values. + vnt4 = VarNamedTuple() + for (k, v) in zip(keys(vnt), values(vnt)) + vnt4 = setindex!!(vnt4, v, k) + end + @test vnt == vnt4 end """ A type that has a size but is not an Array. 
Used in ArrayLikeBlock tests.""" @@ -371,26 +377,32 @@ Base.size(st::SizedThing) = st.size @test merge(vnt2, vnt1) == expected_merge_21 end - @testset "keys" begin + @testset "keys and values" begin vnt = VarNamedTuple() @test @inferred(keys(vnt)) == VarName[] + @test @inferred(values(vnt)) == Any[] vnt = setindex!!(vnt, 1.0, @varname(a)) # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. # We should improve type stability of keys(). @test @inferred(keys(vnt)) == [@varname(a)] + @test @inferred(values(vnt)) == [1.0] vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) @test keys(vnt) == [@varname(a), @varname(b)] + @test values(vnt) == [1.0, [1,2,3]] vnt = setindex!!(vnt, 15, @varname(b[2])) @test keys(vnt) == [@varname(a), @varname(b)] + @test values(vnt) == [1.0, [1,15,3]] vnt = setindex!!(vnt, [10], @varname(c.x.y)) @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)] + @test values(vnt) == [1.0, [1,15,3], [10]] vnt = setindex!!(vnt, -1.0, @varname(d[4])) @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0] vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) @test keys(vnt) == [ @@ -400,6 +412,7 @@ Base.size(st::SizedThing) = st.size @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), ] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0] vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) @test keys(vnt) == [ @@ -413,8 +426,10 @@ Base.size(st::SizedThing) = st.size @varname(j[3]), @varname(j[4]), ] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)...] + - vnt = setindex!!(vnt, 1.0, @varname(j[6])) + vnt = setindex!!(vnt, "a", @varname(j[6])) @test keys(vnt) == [ @varname(a), @varname(b), @@ -427,6 +442,7 @@ Base.size(st::SizedThing) = st.size @varname(j[4]), @varname(j[6]), ] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a"] vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) @test keys(vnt) == [ @@ -442,6 +458,7 @@ Base.size(st::SizedThing) = st.size @varname(j[6]), @varname(n[2].a), ] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0] vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) @test keys(vnt) == [ @@ -458,6 +475,7 @@ Base.size(st::SizedThing) = st.size @varname(n[2].a), @varname(o[2:4, 5:5, 11:14]), ] + @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0, SizedThing((3, 1, 4))] end @testset "printing" begin From 086fd7c7ae8f7ad8046342e7af71b3b044e7c61e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 14:36:42 +0000 Subject: [PATCH 075/148] Add length and _to_dense_array for VNT --- src/varnamedtuple.jl | 94 ++++++++++++++++++++++++++++++++++++++++++- test/varnamedtuple.jl | 62 +++++++++++++++++++++++----- 2 files changed, 145 insertions(+), 11 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index e2db48a41..18c097087 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -766,6 +766,76 @@ function Base.values(pa::PartialArray) return vs end +function Base.length(pa::PartialArray) + len = 0 + for ind in CartesianIndices(pa.mask) + @inbounds if !pa.mask[ind] + continue + end + val = getindex(pa.data, ind) + if val isa VarNamedTuple + len += length(val) + else + # Note we don't need to special case here for ArrayLikeBlocks. That's because + # we want to treat index pointing to the same ArrayLikeBlock as contributing to + # the length. 
+ len += 1 + end + end + return len +end + +""" + _dense_array(pa::PartialArray) + +Return a `Base.Array` of the elements of the `PartialArray`. + +If the `PartialArray` has any missing elements that are "within" the block of set elements, +this will error. Likewise, if any elements are blocks set as ArrayLikeBlocks, this will +error. +""" +function _dense_array(pa::PartialArray) + # Find the size of the dense array, by checking what are the largest indices set in pa. + num_dims = ndims(pa) + size_needed = fill(0, num_dims) + # TODO(mhauru) This could be optimised by not looping over the whole array: If e.g. + # (3,3) is set, we have no need to check any indices within the block (3,3). + for ind in CartesianIndices(pa.mask) + @inbounds if !pa.mask[ind] + continue + end + for d in 1:num_dims + size_needed[d] = max(size_needed[d], ind[d]) + end + end + + # Check that all indices within size_needed are set. + slice = ntuple(d -> 1:size_needed[d], num_dims) + if any(.!(pa.mask[slice...])) + throw( + ArgumentError( + "Cannot convert PartialArray to dense Array when some elements within " * + "the dense block are not set.", + ), + ) + end + + retval = pa.data[slice...] + if eltype(pa) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(pa) + for ind in CartesianIndices(retval) + @inbounds if retval[ind] isa ArrayLikeBlock + throw( + ArgumentError( + "Cannot convert PartialArray to dense Array when some elements " * + "are set as ArrayLikeBlocks.", + ), + ) + end + end + end + return retval +end + """ VarNamedTuple{names,Values} @@ -986,6 +1056,21 @@ function Base.values(vnt::VarNamedTuple) return result end +function Base.length(vnt::VarNamedTuple) + len = 0 + for sym in keys(vnt.data) + subdata = vnt.data[sym] + if subdata isa VarNamedTuple + len += length(subdata) + elseif subdata isa PartialArray + len += length(subdata) + else + len += 1 + end + end + return len +end + # The following methods, indexing with ComposedFunction, are exactly the same for # VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and # inner lenses. @@ -1011,7 +1096,13 @@ end # The entry points for getting, setting, and checking, using the familiar functions. Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) -Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) + +# PartialArrays are an implementation detail of VarNamedTuple, and should never be the +# return value of getindex. Thus, we automatically convert them to dense arrays if needed. +_dense_array_if_needed(pa::PartialArray) = _dense_array(pa) +_dense_array_if_needed(x) = x +Base.getindex(vnt::VarNamedTuple, vn::VarName) = _dense_array_if_needed(_getindex(vnt, vn)) + BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) @@ -1056,6 +1147,7 @@ function to_dict(::Type{T}, vnt::VarNamedTuple) where {T<:AbstractDict{<:VarName pairs = splat(Pair).(zip(keys(vnt), values(vnt))) return T(pairs...) 
end +to_dict(vnt::VarNamedTuple) = to_dict(Dict{VarName,Any}, vnt) function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, ::Distribution) return haskey(vnt, vn) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 3d8223a2a..ff7907c08 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -390,19 +390,19 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) @test keys(vnt) == [@varname(a), @varname(b)] - @test values(vnt) == [1.0, [1,2,3]] + @test values(vnt) == [1.0, [1, 2, 3]] vnt = setindex!!(vnt, 15, @varname(b[2])) @test keys(vnt) == [@varname(a), @varname(b)] - @test values(vnt) == [1.0, [1,15,3]] + @test values(vnt) == [1.0, [1, 15, 3]] vnt = setindex!!(vnt, [10], @varname(c.x.y)) @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)] - @test values(vnt) == [1.0, [1,15,3], [10]] + @test values(vnt) == [1.0, [1, 15, 3], [10]] vnt = setindex!!(vnt, -1.0, @varname(d[4])) @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0] vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) @test keys(vnt) == [ @@ -412,7 +412,7 @@ Base.size(st::SizedThing) = st.size @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), ] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0] vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) @test keys(vnt) == [ @@ -426,8 +426,7 @@ Base.size(st::SizedThing) = st.size @varname(j[3]), @varname(j[4]), ] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)...] - + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)...] vnt = setindex!!(vnt, "a", @varname(j[6])) @test keys(vnt) == [ @@ -442,7 +441,7 @@ Base.size(st::SizedThing) = st.size @varname(j[4]), @varname(j[6]), ] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a"] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)..., "a"] vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) @test keys(vnt) == [ @@ -458,7 +457,7 @@ Base.size(st::SizedThing) = st.size @varname(j[6]), @varname(n[2].a), ] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0] vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) @test keys(vnt) == [ @@ -475,7 +474,50 @@ Base.size(st::SizedThing) = st.size @varname(n[2].a), @varname(o[2:4, 5:5, 11:14]), ] - @test values(vnt) == [1.0, [1,15,3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0, SizedThing((3, 1, 4))] + @test values(vnt) == [ + 1.0, + [1, 15, 3], + [10], + -1.0, + 2.0, + fill(1.0, 4)..., + "a", + 1.0, + SizedThing((3, 1, 4)), + ] + end + + @testset "length" begin + vnt = VarNamedTuple() + @test @inferred(length(vnt)) == 0 + + vnt = setindex!!(vnt, 1.0, @varname(a)) + @test @inferred(length(vnt)) == 1 + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test @inferred(length(vnt)) == 2 + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test @inferred(length(vnt)) == 2 + + vnt = setindex!!(vnt, [10, 11], @varname(c.x.y)) + @test @inferred(length(vnt)) == 3 + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test @inferred(length(vnt)) == 4 + + vnt = setindex!!(vnt, ["a", "b"], @varname(d[1:2])) + @test @inferred(length(vnt)) == 6 + + vnt = setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i)) + vnt = setindex!!(vnt, 3.0, 
@varname(e.f[3].g.h[2].j)) + @test @inferred(length(vnt)) == 8 + + vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 2:4, 2, 1:2, 3])) + @test @inferred(length(vnt)) == 14 + + vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) + @test @inferred(length(vnt)) == 14 end @testset "printing" begin From e98b0e8d6224e376de7632ef30e4944fed1ffd28 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 15:27:05 +0000 Subject: [PATCH 076/148] Fix empty and friends for VNT, add tests --- src/varnamedtuple.jl | 63 ++++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 50 +++++++++++++++++++++++++++++++--- 2 files changed, 103 insertions(+), 10 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 18c097087..42f5fd4d1 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -339,6 +339,8 @@ function Base.hash(pa::PartialArray, h::UInt) return h end +Base.isempty(pa::PartialArray) = !any(pa.mask) +Base.empty(pa::PartialArray) = PartialArray(similar(pa.data), fill(false, size(pa.mask))) function BangBang.empty!!(pa::PartialArray) for i in eachindex(pa.mask) @inbounds pa.mask[i] = false @@ -896,9 +898,13 @@ function Base.copy(vnt::VarNamedTuple{names}) where {names} ) end -_has_partial_array(::Type{T}) where {T} = false -_has_partial_array(::Type{<:PartialArray}) = true +""" + _has_partial_array(::Type{VarNamedTuple{Names,Values}}) where {Names,Values} +Check if any of the types in the `Values` tuple is or contains a `PartialArray`. + +Recurses into any sub-`VarNamedTuple`s. +""" @generated function _has_partial_array( ::Type{VarNamedTuple{Names,Values}} ) where {Names,Values} @@ -910,22 +916,35 @@ _has_partial_array(::Type{<:PartialArray}) = true return :(return false) end -# TODO(mhauru) Should this recur to PartialArray? -Base.isempty(vnt::VarNamedTuple) = isempty(vnt.data) +_has_partial_array(::Type{T}) where {T} = false +_has_partial_array(::Type{<:PartialArray}) = true -# TODO(mhauru) Should this in fact keep the PartialArrays in place, but set them all to have -# mask = fill(false, size(pa.mask))? That might save some allocations. Base.empty(::VarNamedTuple) = VarNamedTuple() +""" + empty!!(vnt::VarNamedTuple) + +Create an empty version of `vnt` in place. + +This differs from `Base.empty` in that any `PartialArray`s contained within `vnt` are kept +but have their contents deleted, rather than being removed entirely. This means that + +1) The result has a "memory" of how many dimensions different variables had, and you cannot, + for example, set `a[1,2]` after emptying a `VarNamedTuple` that had only `a[1]` defined. +2) Memory allocations may be reduced when reusing `VarNamedTuple`s, since the internal + `PartialArray`s do not need to be reallocated from scratch. +""" @generated function BangBang.empty!!(vnt::VarNamedTuple{Names,Values}) where {Names,Values} if !_has_partial_array(VarNamedTuple{Names,Values}) return :(return VarNamedTuple()) end + # Check all the fields of the NamedTuple, and keep the ones that contain PartialArrays, + # calling empty!! on them recursively. 
new_names = () new_values = () for (name, ValType) in zip(Names, Values.parameters) if _has_partial_array(ValType) - new_values = (new_values..., :(BangBang.empty!!(vnt.data[$(QuoteNode(name))]))) + new_values = (new_values..., :(BangBang.empty!!(vnt.data.$name))) new_names = (new_names..., name) end end @@ -934,6 +953,36 @@ Base.empty(::VarNamedTuple) = VarNamedTuple() end end +@generated function Base.isempty(vnt::VarNamedTuple{Names,Values}) where {Names,Values} + if isempty(Names) + return :(return true) + end + if !_has_partial_array(VarNamedTuple{Names,Values}) + return :(return false) + end + exs = Expr[] + for (name, ValType) in zip(Names, Values.parameters) + if !_has_partial_array(ValType) + return :(return false) + end + push!( + exs, + quote + val = vnt.data.$name + if val isa VarNamedTuple || val isa PartialArray + if !Base.isempty(val) + return false + end + else + return false + end + end, + ) + end + push!(exs, :(return true)) + return Expr(:block, exs...) +end + """ varname_to_lens(name::VarName{S}) where {S} diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index ff7907c08..dc01cdaf8 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -5,7 +5,7 @@ using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock using AbstractPPL: VarName, concretize, prefix -using BangBang: setindex!! +using BangBang: setindex!!, empty!! """ test_invariants(vnt::VarNamedTuple) @@ -15,8 +15,11 @@ Test properties that should hold for all VarNamedTuples. Uses @test for all the tests. Intended to be called inside a @testset. """ function test_invariants(vnt::VarNamedTuple) + # These will be needed repeatedly. + vnt_keys = keys(vnt) + vnt_values = values(vnt) # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. - for k in keys(vnt) + for k in vnt_keys @test haskey(vnt, k) v = getindex(vnt, k) # ArrayLikeBlocks are an implementation detail, and should not be exposed through @@ -40,10 +43,22 @@ function test_invariants(vnt::VarNamedTuple) @test merge(VarNamedTuple(), vnt) == vnt # Check that the VNT can be constructed back from its keys and values. vnt4 = VarNamedTuple() - for (k, v) in zip(keys(vnt), values(vnt)) + for (k, v) in zip(vnt_keys, vnt_values) vnt4 = setindex!!(vnt4, v, k) end @test vnt == vnt4 + # Check that vnt isempty only if it has no keys + was_empty = isempty(vnt) + @test was_empty == isempty(vnt_keys) + @test was_empty == isempty(vnt_values) + # Check that vnt can be emptied + @test empty(vnt) == VarNamedTuple() + emptied_vnt = empty!!(copy(vnt)) + @test isempty(emptied_vnt) + @test isempty(keys(emptied_vnt)) + @test isempty(values(emptied_vnt)) + # Check that the copy protected the original vnt from being modified. + @test isempty(vnt) == was_empty end """ A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" @@ -520,6 +535,35 @@ Base.size(st::SizedThing) = st.size @test @inferred(length(vnt)) == 14 end + @testset "empty" begin + # test_invariants already checks that many different kinds of VarNamedTuples can be + # emptied with empty and empty!!. What remains to check here is that + # 1) isempty gives the expected results: + vnt = VarNamedTuple() + @test @inferred(isempty(vnt)) == true + vnt = setindex!!(vnt, 1.0, @varname(a)) + @test @inferred(isempty(vnt)) == false + + vnt = VarNamedTuple() + vnt = setindex!!(vnt, [], @varname(a[1])) + @test @inferred(isempty(vnt)) == false + + # 2) empty!! 
keeps PartialArrays in place: + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(a[1:3]))) + vnt = @inferred(empty!!(vnt)) + @test !haskey(vnt, @varname(a[1])) + @test !haskey(vnt, @varname(a[1:3])) + @test haskey(vnt, @varname(a)) + @test_throws BoundsError getindex(vnt, @varname(a[1])) + @test_throws BoundsError getindex(vnt, @varname(a[1:3])) + @test getindex(vnt, @varname(a)) == [] + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(a[2:4]))) + @test @inferred(getindex(vnt, @varname(a[2:4]))) == [1, 2, 3] + @test haskey(vnt, @varname(a[2:4])) + @test !haskey(vnt, @varname(a[1])) + end + @testset "printing" begin vnt = VarNamedTuple() io = IOBuffer() From 07bf11f68c07b73c3b7a710a480b75ee75ea0c1e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 15:36:01 +0000 Subject: [PATCH 077/148] Add tests of VNT densification --- test/varnamedtuple.jl | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index dc01cdaf8..7b81708ed 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -22,9 +22,10 @@ function test_invariants(vnt::VarNamedTuple) for k in vnt_keys @test haskey(vnt, k) v = getindex(vnt, k) - # ArrayLikeBlocks are an implementation detail, and should not be exposed through - # getindex. + # ArrayLikeBlocks and PartialArrays are implementation details, and should not be + # exposed through getindex. @test !(v isa ArrayLikeBlock) + @test !(v isa PartialArray) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 @test isequal(vnt, vnt2) @@ -564,6 +565,36 @@ Base.size(st::SizedThing) = st.size @test !haskey(vnt, @varname(a[1])) end + @testset "densification" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 1)) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 2)) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 1)) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 2]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 2)) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[3, 3]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((2,)), @varname(x[1:2]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(x))) + end + @testset "printing" begin vnt = VarNamedTuple() io = IOBuffer() From 852609d784d40bda8a29c7c5e2d6320df8fc039c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 16:10:19 +0000 Subject: [PATCH 078/148] Fix some comments --- src/test_utils/models.jl | 4 ++-- src/values_as_in_model.jl | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 
848ed35d4..c7fe623fe 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -8,7 +8,7 @@ # Some additionally contain an implementation of `rand_prior_true`. """ - varname(model::Model) + varnames(model::Model) Return the VarNames defined in `model`, as a Vector. """ @@ -22,7 +22,7 @@ function varnames end Return the VarNames in `model`, with any ranges or colons split into individual indices. -The default implementation is to just return `varname(model)`. If something else is needed, +The default implementation is to just return `varnames(model)`. If something else is needed, this should be defined separately. """ varnames_split(model::Model) = varnames(model) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d7d2da18a..365ace706 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -30,6 +30,9 @@ end accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel +# TODO(mhauru) We could start using reset!!, which could call empty!! on the VarNamedTuple. +# This would create VarNamedTuples that share memory with the original one, saving +# allocations but also making them not capable of taking in any arbitrary VarName. function _zero(acc::ValuesAsInModelAccumulator) return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq) end From 4a585adb3cbb01193a54ed45fb7eab31a4289d39 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 17 Dec 2025 16:13:15 +0000 Subject: [PATCH 079/148] Apply polish --- src/values_as_in_model.jl | 4 +++- src/varnamedtuple.jl | 16 +++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 365ace706..992cbdc8d 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -49,7 +49,9 @@ function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumula end function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val) - # TODO(mhauru) Why do we deepcopy here? + # TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model + # body can go mutating the object without that reactively affecting the value in the + # accumulator, which should be as it was at `~` time. Could there be a way around this? Accessors.@reset acc.values = setindex!!(acc.values, deepcopy(val), vn) return acc end diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 42f5fd4d1..bb1f4a14b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -731,15 +731,15 @@ function Base.keys(pa::PartialArray) subkeys = keys(val) for vn in subkeys sublens = _varname_to_lens(vn) - ks = push!!(ks, _compose_no_identity(sublens, lens)) + push!(ks, _compose_no_identity(sublens, lens)) end elseif val isa ArrayLikeBlock if !(val.inds in alb_inds_seen) - ks = push!!(ks, IndexLens(Tuple(val.inds))) + push!(ks, IndexLens(Tuple(val.inds))) push!(alb_inds_seen, val.inds) end else - ks = push!!(ks, lens) + push!(ks, lens) end end return ks @@ -779,7 +779,7 @@ function Base.length(pa::PartialArray) len += length(val) else # Note we don't need to special case here for ArrayLikeBlocks. That's because - # we want to treat index pointing to the same ArrayLikeBlock as contributing to + # we treat every index pointing to the same ArrayLikeBlock as contributing to # the length. len += 1 end @@ -792,9 +792,11 @@ end Return a `Base.Array` of the elements of the `PartialArray`. -If the `PartialArray` has any missing elements that are "within" the block of set elements, -this will error. 
Likewise, if any elements are blocks set as ArrayLikeBlocks, this will
-error.
+If the `PartialArray` has any missing elements that are within the block of set elements,
+this will error. For instance, if `pa` is two-dimensional and (2,2) is set, but one of
+(1,1), (1,2), or (2,1) is not.
+
+Likewise, if `pa` includes any blocks set as `ArrayLikeBlocks`, this will error.
 """
 function _dense_array(pa::PartialArray)
     # Find the size of the dense array, by checking what are the largest indices set in pa.

From dc6291d9ec83bf6ddf214f51d28003d669c1d174 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 18 Dec 2025 12:29:29 +0000
Subject: [PATCH 080/148] Fix and improve map!! and apply!!

---
 src/varnamedtuple.jl  | 121 +++++++++++++++++++++++++++++++++++++-----
 test/varnamedtuple.jl |  83 ++++++++++++++++++++++++++++-
 2 files changed, 190 insertions(+), 14 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index bb1f4a14b..c5cb2c681 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -8,7 +8,7 @@ using BangBang
 using Accessors
 using ..DynamicPPL: _compose_no_identity

-export VarNamedTuple
+export VarNamedTuple, map!!, apply!!

 # We define our own getindex, setindex!!, and haskey functions, which we use to
 # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be
@@ -19,12 +19,33 @@ export VarNamedTuple
 # 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays.
 function _getindex end
 function _haskey end
+
+"""
+    _setindex!!(collection, value, key; allow_new=Val(true))
+
+Like `setindex!!`, but special-cased for `VarNamedTuple` and `PartialArray` to recurse
+into nested structures.
+
+The `allow_new` keyword argument is a performance optimisation: If it is set to
+`Val(false)`, the function can assume that the key being set already exists in `collection`.
+This allows skipping some code paths, which may have a minor benefit at runtime, but more
+importantly, allows for better constant propagation and type stability at compile time.
+
+`allow_new` being set to `Val(false)` does _not_ guarantee that no new keys will be added.
+It only gives the implementation of `_setindex!!` the permission to assume that the key
+already exists. Setting it to `Val(false)` should be done only when the caller is sure that
+the key already exists; anything else is a bug in the caller.
+
+Most methods of `_setindex!!` ignore the `allow_new` keyword argument, as they have no use for
+it. See the method for setting values in a `VarNamedTuple` with a `ComposedFunction` for
+when it is useful.
+"""
 function _setindex!! end

 _getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...)
 _haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices)
 _haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...)
-function _setindex!!(arr::AbstractArray, value, optic::IndexLens)
+function _setindex!!(arr::AbstractArray, value, optic::IndexLens; allow_new=Val(true))
     return setindex!!(arr, value, optic.indices...)
 end
@@ -451,7 +472,7 @@ end
 _getindex(pa::PartialArray, optic::IndexLens) = _getindex(pa, optic.indices...)
 _haskey(pa::PartialArray, optic::IndexLens) = _haskey(pa, optic.indices)
-function _setindex!!(pa::PartialArray, value, optic::IndexLens)
+function _setindex!!(pa::PartialArray, value, optic::IndexLens; allow_new=Val(true))
     return _setindex!!(pa, value, optic.indices...)
 end

@@ -1006,11 +1027,13 @@ _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S)
 _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true
 _haskey(::VarNamedTuple, ::IndexLens) = false

-function _setindex!!(vnt::VarNamedTuple, value, name::VarName)
-    return _setindex!!(vnt, value, _varname_to_lens(name))
+function _setindex!!(vnt::VarNamedTuple, value, name::VarName; allow_new=Val(true))
+    return _setindex!!(vnt, value, _varname_to_lens(name); allow_new=allow_new)
 end

-function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S}
+function _setindex!!(
+    vnt::VarNamedTuple, value, ::PropertyLens{S}; allow_new=Val(true)
+) where {S}
     # I would like for this to just read
     # return VarNamedTuple(_setindex!!(vnt.data, value, S))
     # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the
@@ -1041,13 +1064,13 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2)
     return Expr(:block, exs...)
 end

-# TODO(mhauru) The below remains unfinished an undertested. I think it's incorrect for more
-# complex VarNames. It is unexported though.
 """
     apply!!(func, vnt::VarNamedTuple, name::VarName)

 Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`.

+Like `map!!`, but only for a single `VarName`.
+
 ```jldoctest
 julia> using DynamicPPL: VarNamedTuple, setindex!!
@@ -1069,9 +1092,71 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName)
     end
     subdata = _getindex(vnt, name)
     new_subdata = func(subdata)
-    return _setindex!!(vnt, new_subdata, name)
+    # The allow_new=Val(false) is a performance optimisation: Since we've already checked
+    # that the key exists, we know that no new fields will be created.
+    return _setindex!!(vnt, new_subdata, name; allow_new=Val(false))
+end
+
+"""
+    _map_recursive!!(func, x)
+
+Call `func` on `x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which case
+call `_map_recursive!!` recursively on all their elements.
+
+This is the internal implementation of `map!!`, but because it has a method defined for
+literally every type in existence, we hide it behind the interface of the more
+discriminating `map!!`. It makes the implementation a bit simpler, compared to checking
+element types within `map!!` itself.
+"""
+_map_recursive!!(func, x) = func(x)
+
+function _map_recursive!!(func, pa::PartialArray)
+    # Ask the compiler to infer the return type of applying func to eltype(pa).
+    new_et = Core.Compiler.return_type(x -> _map_recursive!!(func, x), Tuple{eltype(pa)})
+    new_data = if new_et <: eltype(pa)
+        pa.data
+    else
+        similar(pa.data, new_et)
+    end
+    @inbounds for i in eachindex(pa.mask)
+        if pa.mask[i]
+            new_data[i] = _map_recursive!!(func, pa.data[i])
+        end
+    end
+    # The above type inference may be overly conservative, so we concretise the eltype.
+    return _concretise_eltype!!(PartialArray(new_data, pa.mask))
+end
+
+function _map_recursive!!(func, alb::ArrayLikeBlock)
+    new_block = _map_recursive!!(func, alb.block)
+    if size(new_block) != size(alb.block)
+        throw(
+            DimensionMismatch(
Tried to change from" * + "$(size(alb.block)) to $(size(new_block)).", + ), + ) + end + return ArrayLikeBlock(new_block, alb.inds) +end + +@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} + exs = Expr[] + for name in Names + push!(exs, :(_map_recursive!!(func, vnt.data.$name))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end end +""" + map!!(func, vnt::VarNamedTuple) + +Apply `func` to all set elements of the `vnt`, in place if possible. +""" +map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) + function Base.keys(vnt::VarNamedTuple) result = VarName[] for sym in keys(vnt.data) @@ -1132,13 +1217,23 @@ function _getindex(x::VNT_OR_PA, optic::ComposedFunction) return _getindex(subdata, optic.outer) end -function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction) +# The allow_new keyword argument is a performance optimisation that helps constant +# propagation and type inference by avoiding any possible dynamic dispatch calls to +# `make_leaf`. It should only be set to `Val(false) if we are sure that the key already +# exists, and thus there would be no need to call `make_leaf`. +function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction; allow_new=Val(true)) sub = if _haskey(vnt, optic.inner) - _setindex!!(_getindex(vnt, optic.inner), value, optic.outer) - else + _setindex!!(_getindex(vnt, optic.inner), value, optic.outer; allow_new=allow_new) + elseif allow_new isa Val{true} make_leaf(value, optic.outer) + else + # If this branch is ever reached, then someone has used allow_new=Val(false) + # incorrectly. + error(""" + _setindex was called with allow_new=Val(false) but the key does not exist. + This indicates a bug in DynamicPPL: Please file an issue on GitHub.""") end - return _setindex!!(vnt, sub, optic.inner) + return _setindex!!(vnt, sub, optic.inner; allow_new=allow_new) end function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7b81708ed..25cd14e49 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock +using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map!!, apply!! using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!!, empty!! @@ -741,6 +741,87 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end + + @testset "map!! and apply!!" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a))) + vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) + vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d))) + vnt = @inferred(setindex!!(vnt, "a", @varname(e.f[3].g.h[2].i))) + # The below can't be type stable because the element type of `h` depends on whether + # we are setting `h[2].j` (which overwrites the earlier `h[2]`) or some other + # `h[index].j` (which would leave both `h[2].i` and `h[index].j` in the same array). 
+ vnt = setindex!!(vnt, 5.0, @varname(e.f[3].g.h[2].j)) + vnt = @inferred( + setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4])) + ) + test_invariants(vnt) + + struct AnotherSizedThing{T<:Tuple} + size::T + end + Base.size(st::AnotherSizedThing) = st.size + + function f(val) + if val isa Int + return val + 10 + elseif val isa AbstractVector{Int} + return val .+ 10 + elseif val isa Float64 + return val + 1.0 + elseif val isa AbstractVector{Float64} + return val .- 1.0 + elseif val isa String + return string(val, "b") + elseif val isa SizedThing + return AnotherSizedThing(size(val)) + else + error("Unexpected value type $(typeof(val))") + end + end + + vnt_mapped = @inferred(map!!(f, copy(vnt))) + test_invariants(vnt_mapped) + @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 + @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] + @test @inferred(getindex(vnt_mapped, @varname(c.d))) == [2.0] + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 + @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + + vnt_applied = @inferred(apply!!(f, vnt, @varname(a))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(b[1:2]))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(c.d))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].i))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].j))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 + + # This can't be type stable because y.z might have many elements set, and we can't + # know at compile time that this sets the only one, thus allowing the element type + # to be AnotherSizedThing. + vnt_applied = apply!!(f, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + end end end From dc6291d9ec83bf6ddf214f51d28003d669c1d174 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 15:29:30 +0000 Subject: [PATCH 081/148] mapreduce and nested PartialArrays --- src/varnamedtuple.jl | 83 +++++++++++++++++++++++++++++++------ test/varnamedtuple.jl | 95 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 156 insertions(+), 22 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index c5cb2c681..b829523d7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -286,10 +286,19 @@ function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively # copy. 
     pa_copy = PartialArray(copy(pa.data), copy(pa.mask))
-    if VarNamedTuple <: eltype(pa) || eltype(pa) <: VarNamedTuple
+    et = eltype(pa)
+    if (
+        VarNamedTuple <: et ||
+        et <: VarNamedTuple ||
+        PartialArray <: et ||
+        et <: PartialArray
+    )
         @inbounds for i in eachindex(pa.mask)
-            if pa.mask[i] && pa_copy.data[i] isa VarNamedTuple
-                pa_copy.data[i] = copy(pa.data[i])
+            if pa.mask[i]
+                val = @inbounds pa_copy.data[i]
+                if val isa VarNamedTuple || val isa PartialArray
+                    pa_copy.data[i] = copy(val)
+                end
             end
         end
     end
@@ -754,6 +763,11 @@ function Base.keys(pa::PartialArray)
             sublens = _varname_to_lens(vn)
             push!(ks, _compose_no_identity(sublens, lens))
         end
+    elseif val isa PartialArray
+        subkeys = keys(val)
+        for sublens in subkeys
+            push!(ks, _compose_no_identity(sublens, lens))
+        end
     elseif val isa ArrayLikeBlock
         if !(val.inds in alb_inds_seen)
             push!(ks, IndexLens(Tuple(val.inds)))
@@ -774,7 +788,7 @@ function Base.values(pa::PartialArray)
             continue
         end
         val = getindex(pa.data, ind)
-        if val isa VarNamedTuple
+        if val isa VarNamedTuple || val isa PartialArray
             subvalues = values(val)
             vs = push!!(vs, subvalues...)
         elseif val isa ArrayLikeBlock
@@ -796,7 +810,7 @@ function Base.length(pa::PartialArray)
             continue
         end
        val = getindex(pa.data, ind)
-        if val isa VarNamedTuple
+        if val isa VarNamedTuple || val isa PartialArray
             len += length(val)
         else
             # Note we don't need to special case here for ArrayLikeBlocks. That's because
@@ -1157,6 +1171,55 @@ Apply `func` to all set elements of the `vnt`, in place if possible.
 """
 map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt)
 
+function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing)
+    if init === nothing
+        throw(
+            ArgumentError(
+                "mapreduce without init is not implemented for VarNamedTuple."
+            ),
+        )
+    end
+    return _mapreduce_recursive(f, op, vnt, init)
+end
+
+_mapreduce_recursive(f, op, x, init) = op(init, f(x))
+_mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block))
+
+@generated function _mapreduce_recursive(
+    f, op, vnt::VarNamedTuple{Names}, init
+) where {Names}
+    exs = Expr[]
+    push!(
+        exs,
+        quote
+            result = init
+        end,
+    )
+    for name in Names
+        push!(exs, :(result = _mapreduce_recursive(f, op, vnt.data.$name, result)))
+    end
+    push!(exs, :(return result))
+    return Expr(:block, exs...)
+end + +function _mapreduce_recursive(f, op, pa::PartialArray, init) + result = init + albs_seen = Set{ArrayLikeBlock}() + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] + val = @inbounds pa.data[i] + if val isa ArrayLikeBlock + if val in albs_seen + continue + end + push!(albs_seen, val) + end + result = _mapreduce_recursive(f, op, pa.data[i], result) + end + end + return result +end + function Base.keys(vnt::VarNamedTuple) result = VarName[] for sym in keys(vnt.data) @@ -1179,10 +1242,7 @@ function Base.values(vnt::VarNamedTuple) result = Any[] for sym in keys(vnt.data) subdata = vnt.data[sym] - if subdata isa VarNamedTuple - subvalues = values(subdata) - append!(result, subvalues) - elseif subdata isa PartialArray + if subdata isa VarNamedTuple || subdata isa PartialArray subvalues = values(subdata) append!(result, subvalues) else @@ -1196,9 +1256,7 @@ function Base.length(vnt::VarNamedTuple) len = 0 for sym in keys(vnt.data) subdata = vnt.data[sym] - if subdata isa VarNamedTuple - len += length(subdata) - elseif subdata isa PartialArray + if subdata isa VarNamedTuple || subdata isa PartialArray len += length(subdata) else len += 1 @@ -1245,6 +1303,7 @@ Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) # PartialArrays are an implementation detail of VarNamedTuple, and should never be the # return value of getindex. Thus, we automatically convert them to dense arrays if needed. +# TODO(mhauru) The below doesn't handle nested PartialArrays. Is that a problem? _dense_array_if_needed(pa::PartialArray) = _dense_array(pa) _dense_array_if_needed(x) = x Base.getindex(vnt::VarNamedTuple, vn::VarName) = _dense_array_if_needed(_getindex(vnt, vn)) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 25cd14e49..fe0d6e1b2 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -272,6 +272,17 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, vn) @test @inferred(getindex(vnt, vn)) == x test_invariants(vnt) + + # Indices on indices + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a[1][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, [1], @varname(b[1].c[1]))) + @test @inferred(getindex(vnt, @varname(b[1].c[1]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(e[3, 2].f[2, 2][10, 10]))) + @test @inferred(getindex(vnt, @varname(e[3, 2].f[2, 2][10, 10]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(g[3, 2][10, 10].h[2, 2]))) + @test @inferred(getindex(vnt, @varname(g[3, 2][10, 10].h[2, 2]))) == [1] end @testset "equality and hash" begin @@ -352,15 +363,33 @@ Base.size(st::SizedThing) = st.size expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, 1, @varname(e.b[1][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[2][13])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.b[1][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[2][13])) + vnt1 = setindex!!(vnt1, 1, @varname(e.b[3][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[3][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[3][13])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, 1, @varname(e.b[4][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[4][14])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.b[4][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[4][14])) + @test @inferred(merge(vnt1, vnt2)) == 
expected_merge + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) expected_merge = setindex!!( expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) ) - vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1][14, 13])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13])) + expected_merge = setindex!!( + expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1][14, 13]) + ) + expected_merge = setindex!!( + expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13]) + ) @test merge(vnt1, vnt2) == expected_merge # PartialArrays with different sizes. @@ -501,6 +530,35 @@ Base.size(st::SizedThing) = st.size 1.0, SizedThing((3, 1, 4)), ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + @varname(p[2, 1][2:4, 5:5, 11:14]), + ] + @test values(vnt) == [ + 1.0, + [1, 15, 3], + [10], + -1.0, + 2.0, + fill(1.0, 4)..., + "a", + 1.0, + SizedThing((3, 1, 4)), + SizedThing((3, 1, 4)), + ] end @testset "length" begin @@ -534,6 +592,9 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) @test @inferred(length(vnt)) == 14 + + vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) + @test @inferred(length(vnt)) == 16 end @testset "empty" begin @@ -622,7 +683,7 @@ Base.size(st::SizedThing) = st.size VarNamedTuple(a = "s", b = [1, 2, 3], \ c = PartialArray{Symbol,1}((2,) => :dada))""" - vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3][2, 2].f.g[1:2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @@ -634,11 +695,13 @@ Base.size(st::SizedThing) = st.size VarNamedTuple(a = "s", b = [1, 2, 3], \ c = PartialArray{Symbol,1}((2,) => :dada), \ d = VarNamedTuple(\ - e = PartialArray{VarNamedTuple{(:f,), \ + e = PartialArray{PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}}, 2},1}((3,) => \ + PartialArray{VarNamedTuple{(:f,), \ Tuple{VarNamedTuple{(:g,), \ - Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ - VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ - (2,) => 17.0),),)),))""" + Tuple{PartialArray{Float64, 1}}}}},2}((2, 2) => VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0),),))),))""" end @testset "block variables" begin @@ -742,7 +805,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end - @testset "map!! and apply!!" 
begin
+    @testset "map!!, apply!!, and mapreduce" begin
         vnt = VarNamedTuple()
         vnt = @inferred(setindex!!(vnt, 1, @varname(a)))
         vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2])))
@@ -755,6 +818,7 @@ Base.size(st::SizedThing) = st.size
         vnt = @inferred(
             setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4]))
         )
+        vnt = @inferred(setindex!!(vnt, "", @varname(w[4][3][2, 1])))
         test_invariants(vnt)
 
         struct AnotherSizedThing{T<:Tuple}
@@ -780,6 +844,12 @@ Base.size(st::SizedThing) = st.size
             end
         end
 
+        reduction = mapreduce(identity, vcat, vnt; init=Any[])
+        @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "")
+        reduction = mapreduce(f, vcat, vnt; init=Any[])
+        @test reduction ==
+            vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b")
+
         vnt_mapped = @inferred(map!!(f, copy(vnt)))
         test_invariants(vnt_mapped)
         @test @inferred(getindex(vnt_mapped, @varname(a))) == 11
@@ -789,6 +859,7 @@ Base.size(st::SizedThing) = st.size
         @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0
         @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) ==
            AnotherSizedThing((2, 2))
+        @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b"
 
         vnt_applied = @inferred(apply!!(f, vnt, @varname(a)))
         test_invariants(vnt_applied)
@@ -821,6 +892,10 @@ Base.size(st::SizedThing) = st.size
         test_invariants(vnt_applied)
         @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) ==
             AnotherSizedThing((2, 2))
+
+        vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(w[4][3][2, 1])))
+        test_invariants(vnt_applied)
+        @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b"
     end
 end
 
From 20ed7551f8d16f6499310a02a2b104168c4fb096 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 18 Dec 2025 15:39:04 +0000
Subject: [PATCH 082/148] Test invariants more

---
 test/varnamedtuple.jl | 37 +++++++++++++++++++++++++++++++------
 1 file changed, 31 insertions(+), 6 deletions(-)

diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index fe0d6e1b2..face2dc42 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -27,7 +27,9 @@ function test_invariants(vnt::VarNamedTuple)
         @test !(v isa ArrayLikeBlock)
         @test !(v isa PartialArray)
         vnt2 = setindex!!(copy(vnt), v, k)
-        @test vnt == vnt2
+        equality = (vnt == vnt2)
+        # The value may be `missing` if vnt itself has values that are missing.
+        @test equality === true || equality === missing
         @test isequal(vnt, vnt2)
         @test hash(vnt) == hash(vnt2)
     end
@@ -36,24 +38,26 @@ function test_invariants(vnt::VarNamedTuple)
     # reconstructability-from-repr property, this will fail. Likewise if any element's
     # repr prints out types that are not in scope in this module, it will fail.
     vnt3 = eval(Meta.parse(repr(vnt)))
-    @test vnt == vnt3
+    equality = (vnt == vnt3)
+    # The value may be `missing` if vnt itself has values that are missing.
+    @test equality === true || equality === missing
     @test isequal(vnt, vnt3)
     @test hash(vnt) == hash(vnt3)
     # Check that merge with an empty VarNamedTuple is a no-op.
-    @test merge(vnt, VarNamedTuple()) == vnt
-    @test merge(VarNamedTuple(), vnt) == vnt
+    @test isequal(merge(vnt, VarNamedTuple()), vnt)
+    @test isequal(merge(VarNamedTuple(), vnt), vnt)
     # Check that the VNT can be constructed back from its keys and values.
vnt4 = VarNamedTuple() for (k, v) in zip(vnt_keys, vnt_values) vnt4 = setindex!!(vnt4, v, k) end - @test vnt == vnt4 + @test isequal(vnt, vnt4) # Check that vnt isempty only if it has no keys was_empty = isempty(vnt) @test was_empty == isempty(vnt_keys) @test was_empty == isempty(vnt_values) # Check that vnt can be emptied - @test empty(vnt) == VarNamedTuple() + @test empty(vnt) === VarNamedTuple() emptied_vnt = empty!!(copy(vnt)) @test isempty(emptied_vnt) @test isempty(keys(emptied_vnt)) @@ -312,6 +316,8 @@ Base.size(st::SizedThing) = st.size expected_isequal = expected_isequal & isequal(v1, v2) expected_doubleequal = expected_doubleequal & (v1 == v2) end + test_invariants(vnt1) + test_invariants(vnt2) @test isequal(vnt1, vnt2) == expected_isequal @test (vnt1 == vnt2) === expected_doubleequal if expected_isequal @@ -335,6 +341,8 @@ Base.size(st::SizedThing) = st.size expected_merge = setindex!!(expected_merge, 2, @varname(c)) expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) @test @inferred(merge(vnt1, vnt2)) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -391,6 +399,8 @@ Base.size(st::SizedThing) = st.size expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13]) ) @test merge(vnt1, vnt2) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) # PartialArrays with different sizes. vnt1 = VarNamedTuple() @@ -406,6 +416,8 @@ Base.size(st::SizedThing) = st.size @test @inferred(merge(vnt1, vnt2)) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) @test @inferred(merge(vnt2, vnt1)) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -420,6 +432,8 @@ Base.size(st::SizedThing) = st.size @test merge(vnt1, vnt2) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) @test merge(vnt2, vnt1) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) end @testset "keys and values" begin @@ -559,6 +573,7 @@ Base.size(st::SizedThing) = st.size SizedThing((3, 1, 4)), SizedThing((3, 1, 4)), ] + test_invariants(vnt) end @testset "length" begin @@ -595,6 +610,7 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) @test @inferred(length(vnt)) == 16 + test_invariants(vnt) end @testset "empty" begin @@ -605,10 +621,12 @@ Base.size(st::SizedThing) = st.size @test @inferred(isempty(vnt)) == true vnt = setindex!!(vnt, 1.0, @varname(a)) @test @inferred(isempty(vnt)) == false + test_invariants(vnt) vnt = VarNamedTuple() vnt = setindex!!(vnt, [], @varname(a[1])) @test @inferred(isempty(vnt)) == false + test_invariants(vnt) # 2) empty!! 
keeps PartialArrays in place: vnt = VarNamedTuple() @@ -624,22 +642,26 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(a[2:4]))) == [1, 2, 3] @test haskey(vnt, @varname(a[2:4])) @test !haskey(vnt, @varname(a[1])) + test_invariants(vnt) end @testset "densification" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 1)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 2)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 1)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) @@ -650,10 +672,12 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 2)) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[3, 3]))) @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, SizedThing((2,)), @varname(x[1:2]))) @test_throws ArgumentError @inferred(getindex(vnt, @varname(x))) + test_invariants(vnt) end @testset "printing" begin @@ -702,6 +726,7 @@ Base.size(st::SizedThing) = st.size Tuple{VarNamedTuple{(:g,), \ Tuple{PartialArray{Float64, 1}}}}},2}((2, 2) => VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0),),))),))""" + test_invariants(vnt) end @testset "block variables" begin From 477b715a12776b30973e3d4e3ed7d53d3183500f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 16:52:54 +0000 Subject: [PATCH 083/148] Work-in-progress VNTVarInfo --- src/DynamicPPL.jl | 5 +- src/chains.jl | 4 +- src/contexts/init.jl | 10 +- src/logdensityfunction.jl | 91 ++++++++------ src/simple_varinfo.jl | 18 +-- src/test_utils/varinfo.jl | 42 ++++--- src/utils.jl | 1 + src/vntvarinfo.jl | 247 +++++++++++++++++++++++++++++++++++++ test/compiler.jl | 4 +- test/logdensityfunction.jl | 8 +- test/test_util.jl | 31 ++--- 11 files changed, 368 insertions(+), 93 deletions(-) create mode 100644 src/vntvarinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 25ca59018..5b831e100 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("varnamedtuple.jl") -using .VarNamedTuples: VarNamedTuple +using .VarNamedTuples: VarNamedTuple, map!!, apply!! include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") @@ -201,7 +201,8 @@ include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -include("varinfo.jl") +# include("varinfo.jl") +include("vntvarinfo.jl") include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") diff --git a/src/chains.jl b/src/chains.jl index 71ca29a8f..dc3a91044 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -67,8 +67,8 @@ end # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to # convert it to a typed varinfo first, hence this method. 
# https://github.com/TuringLang/Turing.jl/issues/2604 -maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) -maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) +# maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) +# maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi """ diff --git a/src/contexts/init.jl b/src/contexts/init.jl index dd9e99421..5422c7c85 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -320,7 +320,9 @@ function tilde_assume!!( insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) val_to_insert, logjac = if insert_transformed_value # Calculate the forward logjac and sum them up. - y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + lt = link_transform(dist) + y, fwd_logjac = with_logabsdet_jacobian(lt, x) + transform = _compose_no_identity(transform, lt) # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian # calculation wastes a lot of time going from linked vectorised -> unlinked -> # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. @@ -360,7 +362,11 @@ function tilde_assume!!( if in_varinfo vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, val_to_insert, dist) + vi = if vi isa VNTVarInfo + push!!(vi, vn, val_to_insert, inverse(transform)) + else + push!!(vi, vn, val_to_insert, dist) + end end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 89e2b5989..adcb319c8 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -13,7 +13,7 @@ using DynamicPPL: OnlyAccsVarInfo, RangeAndLinked, VectorWithRanges, - Metadata, + # Metadata, VarNamedVector, default_accumulators, float_type_with_fallback, @@ -310,45 +310,56 @@ representation, along with whether each variable is linked or unlinked. This function returns a VarNamedTuple mapping all VarNames to their corresponding `RangeAndLinked`. 
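For illustration, a rough sketch of the result's shape (the model here is hypothetical):
for a scalar `x` followed by a length-2 vector `y`, both stored unlinked, one would get

```julia
# Hypothetical values; the constructor argument order (range, is_linked, original_size)
# follows the calls made in the implementation below.
VarNamedTuple(
    x = RangeAndLinked(1:1, false, ()),   # x occupies position 1 of the vector
    y = RangeAndLinked(2:3, false, (2,)), # y occupies positions 2 and 3
)
```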
""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_ranges = VarNamedTuple() +function get_ranges_and_linked(vi::VNTVarInfo) offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_ranges = merge(all_ranges, this_md_others) + vnt = map!!(vi.values) do tv + val = tv.val + range = offset:(offset + length(val) - 1) + offset += length(val) + RangeAndLinked(range, tv.linked, size(val)) end - return all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_ranges -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_ranges = VarNamedTuple() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - orig_size = varnamesize(vn) - all_ranges = BangBang.setindex!!( - all_ranges, RangeAndLinked(range, is_linked, orig_size), vn - ) - offset += length(range) - end - return all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_ranges = VarNamedTuple() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - orig_size = varnamesize(vn) - all_ranges = BangBang.setindex!!( - all_ranges, RangeAndLinked(range, is_linked, orig_size), vn - ) - offset += length(range) - end - return all_ranges, offset + return vnt end + +# function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} +# all_ranges = VarNamedTuple() +# offset = 1 +# for sym in syms +# md = varinfo.metadata[sym] +# this_md_others, offset = get_ranges_and_linked_metadata(md, offset) +# all_ranges = merge(all_ranges, this_md_others) +# end +# return all_ranges +# end +# function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) +# all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) +# return all_ranges +# end +# function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) +# all_ranges = VarNamedTuple() +# offset = start_offset +# for (vn, idx) in md.idcs +# is_linked = md.is_transformed[idx] +# range = md.ranges[idx] .+ (start_offset - 1) +# orig_size = varnamesize(vn) +# all_ranges = BangBang.setindex!!( +# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn +# ) +# offset += length(range) +# end +# return all_ranges, offset +# end +# function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) +# all_ranges = VarNamedTuple() +# offset = start_offset +# for (vn, idx) in vnv.varname_to_index +# is_linked = vnv.is_unconstrained[idx] +# range = vnv.ranges[idx] .+ (start_offset - 1) +# orig_size = varnamesize(vn) +# all_ranges = BangBang.setindex!!( +# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn +# ) +# offset += length(range) +# end +# return all_ranges, offset +# end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9d3fb1925..4add65d6d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -256,15 +256,15 @@ function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFro end # Constructor from `VarInfo`. 
-function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} - values = values_as(vi, D) - return SimpleVarInfo(values, copy(getaccs(vi))) -end -function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} - values = values_as(vi, D) - accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) - return SimpleVarInfo(values, accs) -end +# function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} +# values = values_as(vi, D) +# return SimpleVarInfo(values, copy(getaccs(vi))) +# end +# function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} +# values = values_as(vi, D) +# accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) +# return SimpleVarInfo(values, accs) +# end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6483b29e8..79b92ce13 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -33,26 +33,32 @@ of the varinfo instances. function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) - # VarInfo - vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) - vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) - vi_typed_metadata = DynamicPPL.typed_varinfo(model) - vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) + # # VarInfo + # vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + # vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + # vi_typed_metadata = DynamicPPL.typed_varinfo(model) + # vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) - # SimpleVarInfo - svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) - svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) + # # SimpleVarInfo + # svi_typed = SimpleVarInfo(example_values) + # svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) + # svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - varinfos = map(( - vi_untyped_metadata, - vi_untyped_vnv, - vi_typed_metadata, - vi_typed_vnv, - svi_typed, - svi_untyped, - svi_vnv, - )) do vi + # varinfos = map(( + # vi_untyped_metadata, + # vi_untyped_vnv, + # vi_typed_metadata, + # vi_typed_vnv, + # svi_typed, + # svi_untyped, + # svi_vnv, + # )) do vi + # # Set them all to the same values and evaluate logp. + # vi = update_values!!(vi, example_values, varnames) + # last(DynamicPPL.evaluate!!(model, vi)) + # end + # + varinfos = map((DynamicPPL.typed_varinfo(model),)) do vi # Set them all to the same values and evaluate logp. vi = update_values!!(vi, example_values, varnames) last(DynamicPPL.evaluate!!(model, vi)) diff --git a/src/utils.jl b/src/utils.jl index fe2879182..ed9f3aa13 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -49,6 +49,7 @@ function typed_identity end @inline typed_identity(x) = x @inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = (x, zero(LogProbType)) +@inline Bijectors.inverse(::typeof(typed_identity)) = typed_identity """ @addlogprob!(ex) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl new file mode 100644 index 000000000..9b6ee2c7e --- /dev/null +++ b/src/vntvarinfo.jl @@ -0,0 +1,247 @@ +struct VNTVarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + values::T + accs::Accs +end + +# TODO(mhauru) Make this renaming permanent. 
+const VarInfo = VNTVarInfo
+
+struct TransformedValue{ValType,TransformType}
+    val::ValType
+    linked::Bool
+    transform::TransformType
+end
+
+VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators())
+
+function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
+    return VNTVarInfo(Random.default_rng(), model, init_strategy)
+end
+
+function VNTVarInfo(
+    rng::Random.AbstractRNG,
+    model::Model,
+    init_strategy::AbstractInitStrategy=InitFromPrior(),
+)
+    return last(init!!(rng, model, VNTVarInfo(), init_strategy))
+end
+
+getaccs(vi::VNTVarInfo) = vi.accs
+setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs)
+
+transformation(::VNTVarInfo) = DynamicTransformation()
+
+Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn)
+
+Base.length(vi::VNTVarInfo) = length(vi.values)
+
+function Base.getindex(vi::VNTVarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.transform(tv.val)
+end
+
+Base.isempty(vi::VNTVarInfo) = isempty(vi.values)
+
+# TODO(mhauru) This should be called setindex_internal!!, but that's not the current
+# convention.
+function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName)
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue(val, old_tv.linked, old_tv.transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+# TODO(mhauru) The arguments are in the wrong order, but this is the current convention.
+function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity)
+    new_tv = TransformedValue(val, false, transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+Base.keys(vi::VNTVarInfo) = keys(vi.values)
+
+function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName)
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue(old_tv.val, linked, old_tv.transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+function set_transformed!!(vi::VNTVarInfo, linked::Bool)
+    new_values = map!!(vi.values) do tv
+        TransformedValue(tv.val, linked, tv.transform)
+    end
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+function getindex_internal(vi::VNTVarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.val
+end
+
+getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector)
+
+function is_transformed(vi::VNTVarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.linked
+end
+
+# TODO(mhauru) Other VarInfos have something like this. Do we need it?
+# function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +# return from_vec_transform(dist) +# end + +function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) + return getindex(vi.values, vn).transform +end + +function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) + return from_linked_vec_transform(dist) +end + +function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) + return getindex(vi.values, vn).transform +end + +function change_transform(tv::TransformedValue, new_transform, linked) + val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) + val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) + return TransformedValue(val_new, linked, new_transform), logjac1 + logjac2 +end + +function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + dists = extract_priors(model, vi) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + for vn in vns + new_values = apply!!(new_values, vn) do tv + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) + # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use + # map!!, but it doesn't have access to the VarName. + dists = extract_priors(model, vi) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + vns = keys(vi) + for vn in vns + new_values = apply!!(vi.values, vn) do tv + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + for vn in vns + new_values = apply!!(new_values, vn) do tv + transform = typed_identity + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) + # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use + # map!!, but it doesn't have access to the VarName. + cumulative_logjac = zero(LogProbType) + new_values = vi.values + vns = keys(vi) + for vn in vns + new_values = apply!!(vi.values, vn) do tv + transform = typed_identity + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +# TODO(mhauru) I don't think this should return the internal values, but that's the current +# convention. 
+function values_as(vi::VNTVarInfo, ::Type{Vector}) + return mapreduce(tv -> tovec(tv.val), vcat, vi.values; init=Union{}[]) +end + +# TODO(mhauru) These two are now redundant, just conforming to the old interface +# temporarily. +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return VNTVarInfo(rng, model, init_strategy) +end + +function typed_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return VNTVarInfo(rng, model, init_strategy) +end + +typed_varinfo(vi::VNTVarInfo) = vi + +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return typed_varinfo(Random.default_rng(), model, init_strategy) +end + +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) +end + +function unflatten(vi::VNTVarInfo, vec::AbstractVector) + index = 1 + new_values = map!!(vi.values) do tv + # TODO(mhauru) This is quite crude, assuming that the value stored currently is + # an AbstractArray of some kind that has a size, and that reshape makes sense here. + # I may fix this later, but I'm also tempted to just get rid of unflatten entirely. + # This works for now for making most tests pass. + old_val = tv.val + len = length(old_val) + new_val = reshape(vec[index:(index + len - 1)], size(old_val)) + # If the old_val was a scalar then new_val is a 0-dimensional array. + # Convert it to a scalar. + if !(old_val isa AbstractArray) && length(old_val) == 1 + new_val = new_val[1] + end + index += len + return TransformedValue(new_val, tv.linked, tv.transform) + end + return VNTVarInfo(new_values, vi.accs) +end diff --git a/test/compiler.jl b/test/compiler.jl index 9056f666a..5101bd602 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -236,9 +236,9 @@ module Issue537 end # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) # Non-array variables @model function testmodel_nonarray(x, y) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f96e7bf27..153962c9e 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -18,10 +18,10 @@ using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS @testset "$varinfo_func" for varinfo_func in [ - DynamicPPL.untyped_varinfo, + # DynamicPPL.untyped_varinfo, DynamicPPL.typed_varinfo, - DynamicPPL.untyped_vector_varinfo, - DynamicPPL.typed_vector_varinfo, + # DynamicPPL.untyped_vector_varinfo, + # DynamicPPL.typed_vector_varinfo, ] unlinked_vi = varinfo_func(m) @testset "$islinked" for islinked in (false, true) @@ -38,7 +38,7 @@ using Mooncake: Mooncake # directly range_with_linked = ranges[vn] @test params[range_with_linked.range] == - DynamicPPL.getindex_internal(vi, vn) + DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) # Check that the link status is correct @test range_with_linked.is_linked == islinked end diff --git a/test/test_util.jl b/test/test_util.jl index 94fdbd744..821b1e0db 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,28 +16,31 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -function short_varinfo_name(vi::DynamicPPL.NTVarInfo) - return if DynamicPPL.has_varnamedvector(vi) - "TypedVectorVarInfo" - else - "TypedVarInfo" - end -end -short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" +# function short_varinfo_name(vi::DynamicPPL.NTVarInfo) +# return if DynamicPPL.has_varnamedvector(vi) +# "TypedVectorVarInfo" +# else +# "TypedVarInfo" +# end +# end +# short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +# short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) return "SimpleVarInfo{<:OrderedDict,<:Ref}" end -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) - return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -end +# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) +# return "SimpleVarInfo{<:VarNamedVector,<:Ref}" +# end short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" +# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) +# return "SimpleVarInfo{<:VarNamedVector}" +# end +function short_varinfo_name(::DynamicPPL.VNTVarInfo) + return "VNTVarInfo" end # convenient functions for testing model.jl From 7aa601312b9e86e8139fcaa57c2e2c5782b0c42f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 17:37:39 +0000 Subject: [PATCH 084/148] Fix a bug in link --- src/vntvarinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 9b6ee2c7e..184fbd201 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -137,7 +137,7 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) new_values = vi.values vns = keys(vi) for vn in vns - new_values = apply!!(vi.values, vn) do tv + new_values = apply!!(new_values, vn) do tv dist = getindex(dists, vn) transform = from_linked_vec_transform(dist) new_tv, logjac = change_transform(tv, transform, true) @@ -177,7 +177,7 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) new_values = vi.values vns = keys(vi) for vn in vns - new_values = apply!!(vi.values, vn) do tv + new_values = apply!!(new_values, vn) do tv transform = typed_identity new_tv, logjac = change_transform(tv, transform, false) cumulative_logjac += logjac From d9e5405df819835daa81429e51343659cb444e3d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 19 Dec 2025 10:42:24 +0000 Subject: [PATCH 085/148] Mark a test as broken on 1.10 --- test/logdensityfunction.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f96e7bf27..9f30d7b68 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -113,7 +113,10 @@ end end ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) x = vi[:] - @inferred LogDensityProblems.logdensity(ldf, x) + # The below type inference fails on v1.10. 
+ @test begin + @inferred LogDensityProblems.logdensity(ldf, x) + end broken = (VERSION < v"1.11.0") end end end From 05dd3afea8953b55130e1c07644897f5be86f8df Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 19 Dec 2025 11:45:08 +0000 Subject: [PATCH 086/148] Fix hasvalue and getvalue for VNT --- ext/DynamicPPLMCMCChainsExt.jl | 2 +- src/varnamedtuple.jl | 85 ++++++++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 07324d665..485504766 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -83,7 +83,7 @@ end """ AbstractMCMC.to_samples( ::Type{DynamicPPL.ParamsWithStats}, - chain::MCMCChains.Chains + chain::MCMCChains.Chains, ) Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`. diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index bb1f4a14b..0346ec6e6 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTuples using AbstractPPL using AbstractPPL: AbstractPPL -using Distributions: Distribution +using Distributions: Distributions, Distribution using BangBang using Accessors using ..DynamicPPL: _compose_no_identity @@ -1200,12 +1200,91 @@ function to_dict(::Type{T}, vnt::VarNamedTuple) where {T<:AbstractDict{<:VarName end to_dict(vnt::VarNamedTuple) = to_dict(Dict{VarName,Any}, vnt) -function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, ::Distribution) +function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName) return haskey(vnt, vn) end -function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, ::Distribution) +function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName) return getindex(vnt, vn) end +# TODO(mhauru) The following methods mimic the structure of those in +# AbstractPPLDistributionsExtension, and fall back on converting any PartialArrays to +# dictionaries, and calling the AbstractPPL methods. We should eventually make +# implementations of these directly for PartialArray, and maybe move these methods +# elsewhere. Better yet, once we no longer store VarName values in Dictionaries anywhere, +# and FlexiChains takes over from MCMCChains, this could hopefully all be removed. + +# The only case where the Distribution argument makes a difference is if the distribution +# is multivariate and the values are stored in a PartialArray. + +function AbstractPPL.hasvalue( + vnt::VarNamedTuple, vn::VarName, ::Distributions.UnivariateDistribution +) + return AbstractPPL.hasvalue(vnt, vn) +end + +function AbstractPPL.getvalue( + vnt::VarNamedTuple, vn::VarName, ::Distributions.UnivariateDistribution +) + return AbstractPPL.getvalue(vnt, vn) +end + +function AbstractPPL.hasvalue(vals::VarNamedTuple, vn::VarName, dist::Distribution) + @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." + return AbstractPPL.hasvalue(vals, vn) +end + +function AbstractPPL.getvalue(vals::VarNamedTuple, vn::VarName, dist::Distribution) + @warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`." + return AbstractPPL.getvalue(vals, vn) +end + +const MV_DIST_TYPES = Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, +} + +function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYPES) + if !haskey(vnt, vn) + # Can't even find the parent VarName, there is no hope. 
+        return false
+    end
+    # Note that _getindex, rather than getindex, skips the need to denseify PartialArrays.
+    val = _getindex(vnt, vn)
+    if !(val isa VarNamedTuple || val isa PartialArray)
+        # There is _a_ value. Whether it's the right kind, we do not know, but returning
+        # true is no worse than `hasvalue` returning true for e.g. UnivariateDistributions
+        # whenever there is at least some value.
+        return true
+    end
+    # Convert to VarName-keyed Dict.
+    et = val isa VarNamedTuple ? Any : eltype(val)
+    dval = Dict{VarName,et}()
+    for k in keys(val)
+        # VarNamedTuples have VarNames as keys, PartialArrays have IndexLenses.
+        subvn = val isa VarNamedTuple ? prefix(k, vn) : (k ∘ vn)
+        dval[subvn] = getindex(val, k)
+    end
+    return hasvalue(dval, vn, dist)
+end
+
+function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYPES)
+    # Note that _getindex, rather than getindex, skips the need to denseify PartialArrays.
+    val = _getindex(vnt, vn)
+    if !(val isa VarNamedTuple || val isa PartialArray)
+        return val
+    end
+    # Convert to VarName-keyed Dict.
+    et = val isa VarNamedTuple ? Any : eltype(val)
+    dval = Dict{VarName,et}()
+    for k in keys(val)
+        # VarNamedTuples have VarNames as keys, PartialArrays have IndexLenses.
+        subvn = val isa VarNamedTuple ? prefix(k, vn) : (k ∘ vn)
+        dval[subvn] = getindex(val, k)
+    end
+    return getvalue(dval, vn, dist)
+end
+
 end
From 267c55471f44bb1ae03daf4877a25405bb79b437 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Fri, 19 Dec 2025 11:46:12 +0000
Subject: [PATCH 087/148] Trivial bug fix

---
 test/logdensityfunction.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl
index 9f30d7b68..2e3c56c53 100644
--- a/test/logdensityfunction.jl
+++ b/test/logdensityfunction.jl
@@ -116,6 +116,7 @@ end
             # The below type inference fails on v1.10.
             @test begin
                 @inferred LogDensityProblems.logdensity(ldf, x)
+                true
             end broken = (VERSION < v"1.11.0")
         end
     end
From 76ac5b617218f093c35dfd0b3403f9ba10902a7f Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Fri, 19 Dec 2025 18:06:26 +0000
Subject: [PATCH 088/148] Use skip rather than broken for an inference test

---
 test/logdensityfunction.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl
index 2e3c56c53..b58006de2 100644
--- a/test/logdensityfunction.jl
+++ b/test/logdensityfunction.jl
@@ -117,7 +117,7 @@ end
             @test begin
                 @inferred LogDensityProblems.logdensity(ldf, x)
                 true
-            end broken = (VERSION < v"1.11.0")
+            end skip = (VERSION < v"1.11.0")
         end
     end
From 0c50bd74276dd1e5bd3efa0e00ba0766bc81ee86 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Tue, 6 Jan 2026 14:21:46 +0000
Subject: [PATCH 089/148] Fix a docs typo

Co-authored-by: Penelope Yong
---
 HISTORY.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/HISTORY.md b/HISTORY.md
index 0ad1824dd..bb40b8464 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -40,7 +40,7 @@ Both of the above examples previously created 2-dimensional models, with two dis
 
 TODO(mhauru) This may cause surprising issues when using `eachindex`, which is generally encouraged, e.g.
``` -x = Array{Float64,2}(undef, (3, 3) +x = Array{Float64,2}(undef, (3, 3)) for i in eachindex(x) x[i] ~ Normal() end From 6b211b105739a470d69d11b18d81ba9b069e0fb1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:26:31 +0000 Subject: [PATCH 090/148] Use floatmin in test_utils --- src/test_utils/models.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index dcc2d92a2..283d5f01b 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -680,8 +680,8 @@ function likelihood_optima(model::MultivariateAssumeDemoModels) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". - vals.s[1] = 1e-32 - vals.s[2] = 1e-32 + vals.s[1] = floatmin() + vals.s[2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 @@ -733,8 +733,8 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". - vals.s[1, 1] = 1e-32 - vals.s[1, 2] = 1e-32 + vals.s[1, 1] = floatmin() + vals.s[1, 2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 @@ -783,8 +783,8 @@ function likelihood_optima(model::Model{typeof(demo_nested_colons)}) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". - vals.s.params[1].subparams[1, 1, 1] = 1e-32 - vals.s.params[1].subparams[1, 1, 2] = 1e-32 + vals.s.params[1].subparams[1, 1, 1] = floatmin() + vals.s.params[1].subparams[1, 1, 2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 From 57fd84b20c0ad7c251a415739a64ef186fafa2c7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:38:42 +0000 Subject: [PATCH 091/148] Narrow a skip clause --- test/logdensityfunction.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index b58006de2..7014140b9 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -114,10 +114,11 @@ end ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) x = vi[:] # The below type inference fails on v1.10. + skip = (VERSION < v"1.11.0" && m.f === DynamicPPL.TestUtils.demo_nested_colons) @test begin @inferred LogDensityProblems.logdensity(ldf, x) true - end skip = (VERSION < v"1.11.0") + end skip = skip end end end From 345b6058796a26c34702afcd11537c46ff5c6d35 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:50:23 +0000 Subject: [PATCH 092/148] Use vnt_size instead of Base.size --- docs/src/internals/varnamedtuple.md | 4 ++++ src/DynamicPPL.jl | 2 +- src/contexts/init.jl | 2 +- src/varnamedtuple.jl | 30 ++++++++++++++++++++--------- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index aa08c119d..daa062d2d 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -168,6 +168,10 @@ For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor for anything else that overlaps with `@varname(a[1:5])`. `haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element. +The size of a value, for the purposes of inserting it into a `PartialArray`, is determined by a call to `vnt_size`. +`vnt_size` falls back to calling `Base.size`. +The reason we define a distinct function is to be able to control its behaviour, if necessary, without type piracy. 
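To make this concrete, here is a minimal sketch of how a custom type might opt in. `FlatParams` is a hypothetical type invented for this illustration; it is not part of DynamicPPL:

```julia
using DynamicPPL: VarNamedTuple
import DynamicPPL.VarNamedTuples: vnt_size
using AbstractPPL: @varname
using BangBang: setindex!!

# A hypothetical container that should occupy a whole range of indices as one block.
struct FlatParams
    data::Vector{Float64}
end
# Overload vnt_size rather than Base.size: VarNamedTuple consults vnt_size, so the
# block-size semantics can be customised without redefining Base.size for the type.
vnt_size(fp::FlatParams) = (length(fp.data),)

vnt = VarNamedTuple()
# Stored as a single block covering x[1:3], because vnt_size(value) matches the range.
vnt = setindex!!(vnt, FlatParams([1.0, 2.0, 3.0]), @varname(x[1:3]))
vnt[@varname(x[1:3])]  # returns the whole FlatParams block; other ranges would error
```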
+
 ## Limitations
 
 This design has several benefits, for performance and generality, but it also has limitations:
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 25ca59018..95831062f 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
 # Necessary forward declarations
 include("utils.jl")
 include("varnamedtuple.jl")
-using .VarNamedTuples: VarNamedTuple
+using .VarNamedTuples: VarNamedTuples, VarNamedTuple
 include("contexts.jl")
 include("contexts/default.jl")
 include("contexts/init.jl")
diff --git a/src/contexts/init.jl b/src/contexts/init.jl
index dd9e99421..f5259f1cd 100644
--- a/src/contexts/init.jl
+++ b/src/contexts/init.jl
@@ -215,7 +215,7 @@ struct RangeAndLinked{T<:Tuple}
     original_size::T
 end
 
-Base.size(ral::RangeAndLinked) = ral.original_size
+VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size
 
 """
     VectorWithRanges{Tlink}(
diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 0346ec6e6..359129753 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -8,7 +8,7 @@ using BangBang
 using Accessors
 using ..DynamicPPL: _compose_no_identity
 
-export VarNamedTuple
+export VarNamedTuple, vnt_size
 
 # We define our own getindex, setindex!!, and haskey functions, which we use to
 # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be
@@ -81,6 +81,17 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize
 _unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range
 _unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x
 
+"""
+    vnt_size(x)
+
+Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`.
+
+By default, this falls back onto `Base.size`, but can be overloaded for custom types.
+This notion of size is used to determine whether a value can be set into a `PartialArray`
+as a block; see the docstrings of `PartialArray` and `ArrayLikeBlock` for details.
+"""
+vnt_size(x) = size(x)
+
 """
     ArrayLikeBlock{T,I}
 
@@ -156,11 +167,12 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem
 
 One can set values in a `PartialArray` either element-by-element, or with ranges like
 `arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set
-must either be an `AbstractArray` or otherwise something for which `size(value)` is defined,
-and the size matches the range. If the value is an `AbstractArray`, the elements are copied
-individually, but if it is not, the value is stored as a block, that takes up the whole
-range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done
-with the exact same range of indices, otherwise an error is thrown.
+must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or
+`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the
+range. If the value is an `AbstractArray`, the elements are copied individually, but if it
+is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but
+is only a single object. Getting such a block-value must be done with the exact same range
+of indices, otherwise an error is thrown.
 
 If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will
 check if, after the new value has been set, the element type can be made more concrete.
If so, @@ -612,11 +624,11 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if size(value) != inds_size + if vnt_size(value) != inds_size throw( DimensionMismatch( - "Assigned value has size $(size(value)), which does not match the " * - "size implied by the indices $(map(x -> _length_needed(x), inds)).", + "Assigned value has size $(vnt_size(value)), which does not match " * + "the size implied by the indices $(map(x -> _length_needed(x), inds)).", ), ) end From 4f12e995e6648f1ee4a9c16918aca414a8650167 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 15:11:37 +0000 Subject: [PATCH 093/148] Bugfix --- src/varnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 359129753..4e99aa1ec 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -608,7 +608,7 @@ The value only depends on the types of the arguments, and should be constant pro function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) return _is_multiindex(inds) && !isa(value, AbstractArray) && - hasmethod(size, Tuple{typeof(value)}) + hasmethod(vnt_size, Tuple{typeof(value)}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) From 60b2399551cb3c77e1896f1195aec2bfeef99273 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 16:27:05 +0000 Subject: [PATCH 094/148] Style improvements Co-authored-by: Penelope Yong --- src/varnamedtuple.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0346ec6e6..57af53609 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -342,9 +342,7 @@ end Base.isempty(pa::PartialArray) = !any(pa.mask) Base.empty(pa::PartialArray) = PartialArray(similar(pa.data), fill(false, size(pa.mask))) function BangBang.empty!!(pa::PartialArray) - for i in eachindex(pa.mask) - @inbounds pa.mask[i] = false - end + fill!(pa.mask, false) return pa end @@ -815,7 +813,7 @@ function _dense_array(pa::PartialArray) # Check that all indices within size_needed are set. slice = ntuple(d -> 1:size_needed[d], num_dims) - if any(.!(pa.mask[slice...])) + if !all(pa.mask[slice...]) throw( ArgumentError( "Cannot convert PartialArray to dense Array when some elements within " * From e603229b8d8ae77a32ec489349eee60668220f95 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 16:38:56 +0000 Subject: [PATCH 095/148] Remove to_dict on VNT --- src/varnamedtuple.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 57af53609..cc25c6cd5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1192,12 +1192,6 @@ function make_leaf(value, optic::IndexLens) return _setindex!!(pa, value, optic) end -function to_dict(::Type{T}, vnt::VarNamedTuple) where {T<:AbstractDict{<:VarName}} - pairs = splat(Pair).(zip(keys(vnt), values(vnt))) - return T(pairs...) 
-end -to_dict(vnt::VarNamedTuple) = to_dict(Dict{VarName,Any}, vnt) - function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName) return haskey(vnt, vn) end From 21ba31d1990fc022d16fd4863c6859a0a02f0c9a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 6 Jan 2026 18:41:55 +0000 Subject: [PATCH 096/148] Remove unneeded Union --- src/chains.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chains.jl b/src/chains.jl index 71ca29a8f..beec1d3e1 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -5,7 +5,7 @@ A struct which contains parameter values extracted from a `VarInfo`, along with statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are optional. """ -struct ParamsWithStats{P<:Union{OrderedDict{<:VarName,<:Any},VarNamedTuple},S<:NamedTuple} +struct ParamsWithStats{P<:VarNamedTuple,S<:NamedTuple} params::P stats::S end From cacd42690cda43d1efd01535077eac183418ea3e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 7 Jan 2026 17:00:37 +0000 Subject: [PATCH 097/148] Add vnt_size docstring to API docs --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 20eb1ce35..f687fd90a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -377,6 +377,7 @@ SimpleVarInfo ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple +DynamicPPL.VarNamedTuples.vnt_size ``` ### Accumulators From bdeeb4ab1df4a3eeb678907bbe8e72a4818cd90a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 16:42:31 +0000 Subject: [PATCH 098/148] Update map!! to operate on pairs --- src/varnamedtuple.jl | 58 ++++++++++++++++++++++++++++++++++--------- test/varnamedtuple.jl | 23 +++++++++-------- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index b368ab8cd..faf2a298d 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1112,37 +1112,54 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) end """ - _map_recursive!!(func, x) + _map_recursive!!(func, x, vn) -Call `func` on `x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which case -call `_map_recursive!!` recursively on all their elements.. +Call `func` on `vn => x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which +case call `_map_recursive!!` recursively on all their elements, updating `vn` with the right +prefix. This is the internal implementation of `map!!`, but because it has a method defined for literally every type in existence, we hide it behind the interface of the more discriminating `map!!`. It makes the implementation a bit simpler, compared to checking element types within `map!!` itself. """ -_map_recursive!!(func, x) = func(x) - -function _map_recursive!!(func, pa::PartialArray) - # Ask the compiler to infer the return type of applying func to eltype(pa). - new_et = Core.Compiler.return_type(x -> _map_recursive!!(func, x), Tuple{eltype(pa)}) +_map_recursive!!(func, x, vn) = func(vn => x) + +# TODO(mhauru) The below is type unstable for some complex VarNames. My example case +# for which type stability fails is @varname(e.f[3].g.h[2].i). I don't understand this +# well, but I think it's just because constant propagation gives up at some point, and fails +# to go through the lines that figure out `new_et`. I could be wrong. 
I tried fixing this by +# lifting the first three lines of the function into a generated function, but that seems +# to run into trouble when trying to call Core.Compiler.return_type recursively on the same +# function. An earlier implementation of this function that only operated on the values, +# not on pairs of key => value, was type stable (presumably because it was a bit easier on +# constant propagation). +function _map_recursive!!(func, pa::PartialArray, vn) + # Ask the compiler to infer the return type of applying func recursively to eltype(pa). + index_type = IndexLens{NTuple{ndims(pa),Int}} + new_vn_type = Core.Compiler.return_type(∘, Tuple{index_type,typeof(vn)}) + new_et = Core.Compiler.return_type( + Tuple{typeof(_map_recursive!!),typeof(func),eltype(pa),new_vn_type} + ) new_data = if new_et <: eltype(pa) + # We can reuse the existing data array. pa.data else + # We need to allocate a new data array. similar(pa.data, new_et) end @inbounds for i in eachindex(pa.mask) if pa.mask[i] - new_data[i] = _map_recursive!!(func, pa.data[i]) + new_vn = IndexLens(Tuple(i)) ∘ vn + new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) end end # The above type inference may be overly conservative, so we concretise the eltype. return _concretise_eltype!!(PartialArray(new_data, pa.mask)) end -function _map_recursive!!(func, alb::ArrayLikeBlock) - new_block = _map_recursive!!(func, alb.block) +function _map_recursive!!(func, alb::ArrayLikeBlock, vn) + new_block = _map_recursive!!(func, alb.block, vn) if size(new_block) != size(alb.block) throw( DimensionMismatch( @@ -1157,7 +1174,22 @@ end @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} exs = Expr[] for name in Names - push!(exs, :(_map_recursive!!(func, vnt.data.$name))) + push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} + exs = Expr[] + for name in Names + push!( + exs, + :(_map_recursive!!( + func, vnt.data.$name, AbstractPPL.prefix(vn, VarName{$(QuoteNode(name))}()) + )), + ) end return quote return VarNamedTuple(NamedTuple{Names}(($(exs...),))) @@ -1168,6 +1200,8 @@ end map!!(func, vnt::VarNamedTuple) Apply `func` to all set elements of the `vnt`, in place if possible. + +`func` should accept a pair of `VarName` and value, and return the new value to be set. 
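+
+For example (an illustrative sketch; the entries and names here are arbitrary):
+
+```jldoctest
+julia> using DynamicPPL: VarNamedTuple, map!!, @varname
+
+julia> vnt = VarNamedTuple(; a=1, b=2);
+
+julia> vnt = map!!(pair -> pair.second + 1, vnt);
+
+julia> vnt[@varname(a)], vnt[@varname(b)]
+(2, 3)
+```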
""" map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index face2dc42..0b3076468 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -851,7 +851,7 @@ Base.size(st::SizedThing) = st.size end Base.size(st::AnotherSizedThing) = st.size - function f(val) + function f_val(val) if val isa Int return val + 10 elseif val isa AbstractVector{Int} @@ -869,13 +869,16 @@ Base.size(st::SizedThing) = st.size end end + f_pair(pair) = f_val(pair.second) + reduction = mapreduce(identity, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") - reduction = mapreduce(f, vcat, vnt; init=Any[]) + reduction = mapreduce(f_val, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") - vnt_mapped = @inferred(map!!(f, copy(vnt))) + # vnt_mapped = @inferred(map!!(f, copy(vnt))) + vnt_mapped = map!!(f_pair, copy(vnt)) test_invariants(vnt_mapped) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @@ -886,26 +889,26 @@ Base.size(st::SizedThing) = st.size AnotherSizedThing((2, 2)) @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" - vnt_applied = @inferred(apply!!(f, vnt, @varname(a))) + vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(b[1:2]))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(b[1:2]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(c.d))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(c.d))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].i))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].j))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 @@ -913,12 +916,12 @@ Base.size(st::SizedThing) = st.size # This can't be type stable because y.z might have many elements set, and we can't # know at compile time that this sets the only one, thus allowing the element type # to be AnotherSizedThing. 
- vnt_applied = apply!!(f, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) + vnt_applied = apply!!(f_val, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(w[4][3][2, 1]))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" end From 5498d8279a9667aadba5173171b3edd3220ffe59 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 17:09:48 +0000 Subject: [PATCH 099/148] Split map!! into map_pairs!! and map_values!!, fix some bugs --- src/varnamedtuple.jl | 33 ++++++++++++++++++++------------- test/varnamedtuple.jl | 33 +++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index faf2a298d..2217d35ee 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,7 +8,7 @@ using BangBang using Accessors using ..DynamicPPL: _compose_no_identity -export VarNamedTuple, map!!, apply!! +export VarNamedTuple, map_pairs!!, map_values!!, apply!! # We define our own getindex, setindex!!, and haskey functions, which we use to # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be @@ -1083,7 +1083,7 @@ end Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. -Like `map!!`, but only for a single `VarName`. +Like `map_values!!`, but only for a single `VarName`. ```jldoctest julia> using DynamicPPL: VarNamedTuple, setindex!! @@ -1118,10 +1118,10 @@ Call `func` on `vn => x`, except if `x` is a `VarNamedTuple` or `PartialArray`, case call `_map_recursive!!` recursively on all their elements, updating `vn` with the right prefix. -This is the internal implementation of `map!!`, but because it has a method defined for -literally every type in existence, we hide it behind the interface of the more -discriminating `map!!`. It makes the implementation a bit simpler, compared to checking -element types within `map!!` itself. +This is the internal implementation of `map_pairs!!`, but because it has a method defined +for literally every type in existence, we hide it behind the interface of the more +discriminating `map_pairs!!`. It makes the implementation a bit simpler, compared to +checking element types within `map_pairs!!` itself. """ _map_recursive!!(func, x, vn) = func(vn => x) @@ -1148,7 +1148,7 @@ function _map_recursive!!(func, pa::PartialArray, vn) # We need to allocate a new data array. similar(pa.data, new_et) end - @inbounds for i in eachindex(pa.mask) + @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] new_vn = IndexLens(Tuple(i)) ∘ vn new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) @@ -1163,8 +1163,8 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) if size(new_block) != size(alb.block) throw( DimensionMismatch( - "map!! can't change the size of an ArrayLikeBlock. Tried to change from" * - "$(size(alb.block)) to $(size(new_block)).", + "map_pairs!! can't change the size of an ArrayLikeBlock. 
Tried to change " * + "from $(size(alb.block)) to $(size(new_block)).", ), ) end @@ -1187,7 +1187,7 @@ end push!( exs, :(_map_recursive!!( - func, vnt.data.$name, AbstractPPL.prefix(vn, VarName{$(QuoteNode(name))}()) + func, vnt.data.$name, AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn) )), ) end @@ -1197,13 +1197,20 @@ end end """ - map!!(func, vnt::VarNamedTuple) + map_pairs!!(func, vnt::VarNamedTuple) -Apply `func` to all set elements of the `vnt`, in place if possible. +Apply `func` to all key => value pairs of `vnt`, in place if possible. `func` should accept a pair of `VarName` and value, and return the new value to be set. """ -map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) +map_pairs!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) + +""" + map_values!!(func, vnt::VarNamedTuple) + +Apply `func` to elements of `vnt`, in place if possible. +""" +map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), vnt) function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 0b3076468..13436d39a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,8 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map!!, apply!! +using DynamicPPL.VarNamedTuples: + PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!!, empty!! @@ -64,6 +65,9 @@ function test_invariants(vnt::VarNamedTuple) @test isempty(values(emptied_vnt)) # Check that the copy protected the original vnt from being modified. @test isempty(vnt) == was_empty + # Check that map is a no-op when using identity functions. + @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) + @test isequal(map_values!!(identity, copy(vnt)), vnt) end """ A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" @@ -830,7 +834,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end - @testset "map!!, apply!!, and mapreduce" begin + @testset "map and friends" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1, @varname(a))) vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) @@ -877,8 +881,11 @@ Base.size(st::SizedThing) = st.size @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") - # vnt_mapped = @inferred(map!!(f, copy(vnt))) - vnt_mapped = map!!(f_pair, copy(vnt)) + # TODO(mhauru) This should hopefully be type stable, but fails to be so because of + # some complex VarNames being too much for constant propagation. See comment in + # src/varnamedtuple.jl for more. + vnt_mapped = map_pairs!!(f_pair, copy(vnt)) + @test vnt_mapped == map_values!!(f_val, copy(vnt)) test_invariants(vnt_mapped) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @@ -924,6 +931,24 @@ Base.size(st::SizedThing) = st.size vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" + + # map a function that maps every key => value pair to key => key. 
+        # For this, use a simpler VarNamedTuple, because block variables don't work with
+        # this mapping function. It also allows us to check type stability.
+        vnt = VarNamedTuple()
+        vnt = @inferred(setindex!!(vnt, 1, @varname(a)))
+        vnt = @inferred(setindex!!(vnt, 2, @varname(b[2])))
+        vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d)))
+        vnt = @inferred(setindex!!(vnt, :oi, @varname(y.z[3, 2, 3, 2, 4])))
+        vnt = @inferred(setindex!!(vnt, "", @varname(w[4][2, 1])))
+
+        get_key(pair) = pair.first
+        vnt_key_mapped = @inferred(map_pairs!!(get_key, copy(vnt)))
+        vnt_key_mapped_expected = VarNamedTuple()
+        for k in keys(vnt)
+            vnt_key_mapped_expected = setindex!!(vnt_key_mapped_expected, k, k)
+        end
+        @test vnt_key_mapped == vnt_key_mapped_expected
     end
 end

From 81be716f37455a4121fbeadb174a20063164936b Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 8 Jan 2026 17:31:01 +0000
Subject: [PATCH 100/148] Make mapreduce operate on pairs

---
 src/varnamedtuple.jl  | 80 +++++++++++++++++++++++++++++++++++++------
 test/varnamedtuple.jl | 18 ++++++++--
 2 files changed, 85 insertions(+), 13 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 2217d35ee..4b91fbb12 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -1136,12 +1136,13 @@ _map_recursive!!(func, x, vn) = func(vn => x)
 # constant propagation).
 function _map_recursive!!(func, pa::PartialArray, vn)
     # Ask the compiler to infer the return type of applying func recursively to eltype(pa).
+    et = eltype(pa)
     index_type = IndexLens{NTuple{ndims(pa),Int}}
     new_vn_type = Core.Compiler.return_type(∘, Tuple{index_type,typeof(vn)})
     new_et = Core.Compiler.return_type(
-        Tuple{typeof(_map_recursive!!),typeof(func),eltype(pa),new_vn_type}
+        Tuple{typeof(_map_recursive!!),typeof(func),et,new_vn_type}
     )
-    new_data = if new_et <: eltype(pa)
+    new_data = if new_et <: et
         # We can reuse the existing data array.
         pa.data
     else
@@ -1150,7 +1151,13 @@ function _map_recursive!!(func, pa::PartialArray, vn)
     end
     @inbounds for i in CartesianIndices(pa.mask)
         if pa.mask[i]
-            new_vn = IndexLens(Tuple(i)) ∘ vn
+            val = pa.data[i]
+            # The first two checks on the below line are just a performance optimisation:
+            # They may short circuit at compile time.
+            is_alb =
+                (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock
+            ind = is_alb ? val.inds : Tuple(i)
+            new_vn = IndexLens(ind) ∘ vn
             new_data[i] = _map_recursive!!(func, pa.data[i], new_vn)
         end
     end
@@ -1212,6 +1219,16 @@ Apply `func` to elements of `vnt`, in place if possible.
 """
 map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), vnt)

+"""
+    mapreduce(f, op, vnt::VarNamedTuple; init)
+
+Apply `f` to all elements of `vnt`, and reduce the results using `op`, starting from `init`.
+
+`init` is a keyword argument to conform to the usual `mapreduce` interface in Base, but it
+is not optional.
+
+`f` should accept pairs of `VarName` and value.
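+
+For example (an illustrative sketch):
+
+```jldoctest
+julia> using DynamicPPL: VarNamedTuple
+
+julia> vnt = VarNamedTuple(; a=1, b=2);
+
+julia> mapreduce(pair -> pair.second, +, vnt; init=0)
+3
+```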
+""" function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing throw( @@ -1223,8 +1240,8 @@ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) return _mapreduce_recursive(f, op, vnt, init) end -_mapreduce_recursive(f, op, x, init) = op(init, f(x)) -_mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block)) +_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, init @@ -1237,25 +1254,68 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block)) end, ) for name in Names - push!(exs, :(result = _mapreduce_recursive(f, op, vnt.data.$name, result))) + push!( + exs, + :( + result = _mapreduce_recursive( + f, op, vnt.data.$name, VarName{$(QuoteNode(name))}(), result + ) + ), + ) + end + push!(exs, :(return result)) + return Expr(:block, exs...) +end + +@generated function _mapreduce_recursive( + f, op, vnt::VarNamedTuple{Names}, vn, init +) where {Names} + exs = Expr[] + push!( + exs, + quote + result = init + end, + ) + for name in Names + push!( + exs, + :( + result = _mapreduce_recursive( + f, + op, + vnt.data.$name, + AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn), + result, + ) + ), + ) end push!(exs, :(return result)) return Expr(:block, exs...) end -function _mapreduce_recursive(f, op, pa::PartialArray, init) +function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) result = init + et = eltype(pa) + albs_seen = Set{ArrayLikeBlock}() - @inbounds for i in eachindex(pa.mask) + @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = @inbounds pa.data[i] - if val isa ArrayLikeBlock + # The first two checks on the below line are just a performance optimisation: + # They may short circuit at compile time. + is_alb = + (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock + if is_alb if val in albs_seen continue end push!(albs_seen, val) end - result = _mapreduce_recursive(f, op, pa.data[i], result) + ind = is_alb ? 
val.inds : Tuple(i) + new_vn = IndexLens(ind) ∘ vn + result = _mapreduce_recursive(f, op, pa.data[i], new_vn, result) end end return result diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 13436d39a..fe0417f2b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -875,9 +875,21 @@ Base.size(st::SizedThing) = st.size f_pair(pair) = f_val(pair.second) - reduction = mapreduce(identity, vcat, vnt; init=Any[]) - @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") - reduction = mapreduce(f_val, vcat, vnt; init=Any[]) + val_reduction = mapreduce(pair -> pair.second, vcat, vnt; init=Any[]) + @test val_reduction == + vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") + key_reduction = mapreduce(pair -> pair.first, vcat, vnt; init=Any[]) + @test key_reduction == vcat( + @varname(a), + @varname(b[1]), + @varname(b[2]), + @varname(c.d), + @varname(e.f[3].g.h[2].i), + @varname(e.f[3].g.h[2].j), + @varname(y.z[3, 2:3, 3, 2:3, 4]), + @varname(w[4][3][2, 1]), + ) + reduction = mapreduce(f_pair, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") From 37f4adfb66f70c2d4b3226c1714d4aa15043c89f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 19:08:54 +0000 Subject: [PATCH 101/148] Implement keys and values using mapreduce --- src/varnamedtuple.jl | 145 +++++++------------------------------------ 1 file changed, 21 insertions(+), 124 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 4b91fbb12..d165ca3a5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -378,6 +378,21 @@ function BangBang.empty!!(pa::PartialArray) return pa end +# Length could be defined as a special case of mapreduce, but it's harder to keep it type +# stable that way: If the element type is abstract, we end up calling _mapreduce_recursive +# on an abstract type, which makes the type of the cumulant Any. +function Base.length(pa::PartialArray) + len = 0 + @inbounds for i in eachindex(pa.mask) + if !pa.mask[i] + continue + end + val = pa.data[i] + len += val isa VarNamedTuple || val isa PartialArray ? length(val) : 1 + end + return len +end + """ _concretise_eltype!!(pa::PartialArray) @@ -745,83 +760,6 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) end end -function Base.keys(pa::PartialArray) - # TODO(mhauru) Should this rather be Union{}[]? It would make this very type unstable - # and cause more allocations, but would result in more concrete element types. Same - # question for Base.keys on VNT and Base.values. - ks = Any[] - alb_inds_seen = Set{Tuple}() - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - lens = IndexLens(Tuple(ind)) - val = getindex(pa.data, lens.indices...) 
- if val isa VarNamedTuple - subkeys = keys(val) - for vn in subkeys - sublens = _varname_to_lens(vn) - push!(ks, _compose_no_identity(sublens, lens)) - end - elseif val isa PartialArray - subkeys = keys(val) - for sublens in subkeys - push!(ks, _compose_no_identity(sublens, lens)) - end - elseif val isa ArrayLikeBlock - if !(val.inds in alb_inds_seen) - push!(ks, IndexLens(Tuple(val.inds))) - push!(alb_inds_seen, val.inds) - end - else - push!(ks, lens) - end - end - return ks -end - -function Base.values(pa::PartialArray) - vs = Any[] - albs_seen = Set{ArrayLikeBlock}() - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - val = getindex(pa.data, ind) - if val isa VarNamedTuple || val isa PartialArray - subvalues = values(val) - vs = push!!(vs, subvalues...) - elseif val isa ArrayLikeBlock - if !(val in albs_seen) - vs = push!!(vs, val.block) - push!(albs_seen, val) - end - else - vs = push!!(vs, val) - end - end - return vs -end - -function Base.length(pa::PartialArray) - len = 0 - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - val = getindex(pa.data, ind) - if val isa VarNamedTuple || val isa PartialArray - len += length(val) - else - # Note we don't need to special case here for ArrayLikeBlocks. That's because - # we treat every index pointing to the same ArrayLikeBlock as contributing to - # the length. - len += 1 - end - end - return len -end - """ _dense_array(pa::PartialArray) @@ -1152,11 +1090,7 @@ function _map_recursive!!(func, pa::PartialArray, vn) @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = pa.data[i] - # The first two checks on the below line are just a performance optimisation: - # They may short circuit at compile time. - is_alb = - (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock - ind = is_alb ? val.inds : Tuple(i) + ind = val isa ArrayLikeBlock ? val.inds : Tuple(i) new_vn = IndexLens(ind) ∘ vn new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) end @@ -1303,10 +1237,7 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = @inbounds pa.data[i] - # The first two checks on the below line are just a performance optimisation: - # They may short circuit at compile time. 
-            is_alb =
-                (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock
+            is_alb = val isa ArrayLikeBlock
             if is_alb
                 if val in albs_seen
                     continue
                 end
                 push!(albs_seen, val)
             end
             ind = is_alb ? val.inds : Tuple(i)
             new_vn = IndexLens(ind) ∘ vn
             result = _mapreduce_recursive(f, op, pa.data[i], new_vn, result)
         end
     end
     return result
 end

-function Base.keys(vnt::VarNamedTuple)
-    result = VarName[]
-    for sym in keys(vnt.data)
-        subdata = vnt.data[sym]
-        if subdata isa VarNamedTuple
-            subkeys = keys(subdata)
-            append!(result, [AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys])
-        elseif subdata isa PartialArray
-            subkeys = keys(subdata)
-            append!(result, [VarName{sym}(lens) for lens in subkeys])
-        else
-            push!(result, VarName{sym}())
-        end
-    end
-    return result
-end
-
-function Base.values(vnt::VarNamedTuple)
-    # TODO(mhauru) Same comments as for keys for type stability and Any vs Union{}
-    result = Any[]
-    for sym in keys(vnt.data)
-        subdata = vnt.data[sym]
-        if subdata isa VarNamedTuple || subdata isa PartialArray
-            subvalues = values(subdata)
-            append!(result, subvalues)
-        else
-            push!(result, subdata)
-        end
-    end
-    return result
-end
+Base.keys(vnt::VarNamedTuple) = mapreduce(first, push!, vnt; init=VarName[])
+Base.values(vnt::VarNamedTuple) = mapreduce(pair -> pair.second, push!, vnt; init=Any[])

 function Base.length(vnt::VarNamedTuple)
     len = 0
-    for sym in keys(vnt.data)
-        subdata = vnt.data[sym]
-        if subdata isa VarNamedTuple || subdata isa PartialArray
-            len += length(subdata)
-        else
-            len += 1
-        end
+    for subdata in vnt.data
+        len += subdata isa VarNamedTuple || subdata isa PartialArray ? length(subdata) : 1
     end
     return len
 end

From fc29cc66cc013701f1ed472f8b9e0cbcb845455d Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Fri, 9 Jan 2026 11:23:58 +0000
Subject: [PATCH 102/148] Add more VNT constructors

---
 src/varnamedtuple.jl  | 48 +++++++++++++++++++++++++++++++++++++++++--
 test/varnamedtuple.jl | 34 ++++++++++++++++++++++++++++++--
 2 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index d165ca3a5..cc7648447 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -820,8 +820,25 @@ A `NamedTuple`-like structure with `VarName` keys.

 `VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an
 efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and
-`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Anther notable methods
-is `merge`, which recursively merges two `VarNamedTuple`s.
+`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods
+are `merge` and `subset`.
+
+`VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the same keys
+and values but in different orders are considered different for equality and hashing.
+Iterations such as `keys` and `values` respect this ordering. The ordering is dependent on
+the order in which elements were inserted into the `VarNamedTuple`, though it isn't always
+equal to it. More specifically:
+
+* Any new keys that share a parent `VarName` with an existing key are inserted after
+  that key. For instance, if one first inserts, in order, `@varname(a.x)`, `@varname(b)`,
+  and `@varname(a.y)`, the resulting order will be
+  `(@varname(a.x), @varname(a.y), @varname(b))`.
+* `IndexLens` keys, like `@varname(a[3])` or `@varname(b[2,3,4:5])`, are always iterated
+  in the same order an `Array` with the same indices would be iterated.
+  For instance, if one first inserts, in order, `@varname(a[2])`, `@varname(b)`, and
+  `@varname(a[1])`, the resulting order will be
+  `(@varname(a[1]), @varname(a[2]), @varname(b))`.
+
+Otherwise insertion order is respected.

 There are two major limitations to indexing by VarNamedTuples:
@@ -844,10 +861,37 @@ related to `VarName`s with `IndexLens` components.
 """
 struct VarNamedTuple{Names,Values}
     data::NamedTuple{Names,Values}
+
+    function VarNamedTuple(data::NamedTuple{Names,Values}) where {Names,Values}
+        return new{Names,Values}(data)
+    end
 end

 VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...))

+"""
+    VarNamedTuple(d)
+    VarNamedTuple(nt::NamedTuple)
+
+Create a `VarNamedTuple` from a collection or a `NamedTuple`.
+
+Any collection `d` is assumed to be an iterable of key-value pairs, where the keys are
+`VarName`s. This could be an `AbstractDict`, a vector of `Pair`s or `Tuple`s, etc. The
+only exception is `NamedTuple`s, for which the `Symbol` keys are converted to `VarName`s.
+
+Note that `VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the
+same keys and values but in different orders are considered different. If `d` does not
+guarantee an iteration order, then the order of the elements in the resulting
+`VarNamedTuple` is undefined.
+"""
+function VarNamedTuple(d)
+    vnt = VarNamedTuple()
+    for (k, v) in d
+        vnt = setindex!!(vnt, v, k)
+    end
+    return vnt
+end
+
 Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data
 Base.isequal(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = isequal(vnt1.data, vnt2.data)
 Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h)
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index fe0417f2b..5efd9fe49 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -1,6 +1,7 @@ module VarNamedTupleTests

 using Combinatorics: Combinatorics
+using OrderedCollections: OrderedDict
 using Test: @inferred, @test, @test_throws, @testset
 using DynamicPPL: DynamicPPL, @varname, VarNamedTuple
 using DynamicPPL.VarNamedTuples:
@@ -19,6 +20,7 @@ function test_invariants(vnt::VarNamedTuple)
     # These will be needed repeatedly.
     vnt_keys = keys(vnt)
     vnt_values = values(vnt)
+
     # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op.
     for k in vnt_keys
         @test haskey(vnt, k)
@@ -34,6 +36,7 @@ function test_invariants(vnt::VarNamedTuple)
         @test isequal(vnt, vnt2)
         @test hash(vnt) == hash(vnt2)
     end
+
     # Check that the printed representation can be parsed back to an equal VarNamedTuple.
     # The below eval test is a bit fragile: If any elements in vnt don't respect the same
     # reconstructability-from-repr property, this will fail. Likewise if any element uses
@@ -44,27 +47,33 @@ function test_invariants(vnt::VarNamedTuple)
     @test equality === true || equality === missing
     @test isequal(vnt, vnt3)
     @test hash(vnt) == hash(vnt3)
+
     # Check that merge with an empty VarNamedTuple is a no-op.
     @test isequal(merge(vnt, VarNamedTuple()), vnt)
     @test isequal(merge(VarNamedTuple(), vnt), vnt)
+
     # Check that the VNT can be constructed back from its keys and values.
vnt4 = VarNamedTuple() for (k, v) in zip(vnt_keys, vnt_values) vnt4 = setindex!!(vnt4, v, k) end @test isequal(vnt, vnt4) + # Check that vnt isempty only if it has no keys was_empty = isempty(vnt) - @test was_empty == isempty(vnt_keys) - @test was_empty == isempty(vnt_values) + @test isequal(was_empty, isempty(vnt_keys)) + @test isequal(was_empty, isempty(vnt_values)) + # Check that vnt can be emptied @test empty(vnt) === VarNamedTuple() emptied_vnt = empty!!(copy(vnt)) @test isempty(emptied_vnt) @test isempty(keys(emptied_vnt)) @test isempty(values(emptied_vnt)) + # Check that the copy protected the original vnt from being modified. @test isempty(vnt) == was_empty + # Check that map is a no-op when using identity functions. @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) @test isequal(map_values!!(identity, copy(vnt)), vnt) @@ -84,12 +93,33 @@ Base.size(st::SizedThing) = st.size vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) test_invariants(vnt1) + vnt2 = VarNamedTuple(; a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) ) test_invariants(vnt2) @test vnt1 == vnt2 + vnt3 = VarNamedTuple((; + a=1.0, b=[1, 2, 3], c=VarNamedTuple((; d=VarNamedTuple((; e="a")))) + )) + test_invariants(vnt3) + @test vnt1 == vnt3 + + vnt4 = VarNamedTuple( + OrderedDict( + @varname(a) => 1.0, @varname(b) => [1, 2, 3], @varname(c.d.e) => "a" + ), + ) + test_invariants(vnt4) + @test vnt1 == vnt4 + + vnt5 = VarNamedTuple(( + (@varname(a), 1.0), (@varname(b), [1, 2, 3]), (@varname(c.d.e), "a") + )) + test_invariants(vnt5) + @test vnt1 == vnt5 + pa1 = PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) pa2 = PartialArray{Float64,1}(; min_size=(16,)) From c6d067720823792297d92242423c8f7e17c527e9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 11:24:39 +0000 Subject: [PATCH 103/148] Add VNT subset --- src/varnamedtuple.jl | 30 +++++++++++++++++++++++++++++- src/vntvarinfo.jl | 11 +++++++++++ test/varnamedtuple.jl | 42 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index bd33397d7..cc7648447 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -6,7 +6,7 @@ using AbstractPPL: AbstractPPL using Distributions: Distributions, Distribution using BangBang using Accessors -using ..DynamicPPL: _compose_no_identity +using ..DynamicPPL: DynamicPPL, _compose_no_identity export VarNamedTuple, map_pairs!!, map_values!!, apply!! @@ -1060,6 +1060,31 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) return Expr(:block, exs...) end +""" + subset(vnt::VarNamedTuple, vns) + +Create a new `VarNamedTuple` containing only the variables subsumed by ones in `vns`. +""" +function DynamicPPL.subset(vnt::VarNamedTuple, vns) + # TODO(mhauru) This could be done more efficiently by generating the code directly, + # because we could short-circuit: For instance, if `vns` contains `a`, we could + # directly include the whole subtree under `a`, without checking each individual + # variable under it. 
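+    # A note on the approach: we fold over every key => value pair in `vnt`, keeping
+    # only those pairs whose key is subsumed by at least one of the given `vns`.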
+ return mapfoldl( + identity, + function (init, pair) + name, value = pair + return if any(vn -> subsumes(vn, name), vns) + setindex!!(init, value, name) + else + init + end + end, + vnt; + init=VarNamedTuple(), + ) +end + """ apply!!(func, vnt::VarNamedTuple, name::VarName) @@ -1218,6 +1243,9 @@ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) return _mapreduce_recursive(f, op, vnt, init) end +# Our mapreduce is always left-associative. +Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) + _mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 184fbd201..ad698d169 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -245,3 +245,14 @@ function unflatten(vi::VNTVarInfo, vec::AbstractVector) end return VNTVarInfo(new_values, vi.accs) end + +function subset(varinfo::VNTVarInfo, vns) + new_values = subset(varinfo.values, vns) + return VNTVarInfo(new_values, map(copy, getaccs(varinfo))) +end + +function Base.merge(varinfo_left::VNTVarInfo, varinfo_right::VNTVarInfo) + new_values = merge(varinfo_left.values, varinfo_right.values) + new_accs = map(copy, getaccs(varinfo_right)) + return VNTVarInfo(new_values, new_accs) +end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 5efd9fe49..655b8e9e5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using OrderedCollections: OrderedDict using Test: @inferred, @test, @test_throws, @testset -using DynamicPPL: DynamicPPL, @varname, VarNamedTuple +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple, subset using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! using AbstractPPL: VarName, concretize, prefix @@ -77,6 +77,10 @@ function test_invariants(vnt::VarNamedTuple) # Check that map is a no-op when using identity functions. @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) @test isequal(map_values!!(identity, copy(vnt)), vnt) + + # Check that subsetting works as expected. + @test isequal(subset(vnt, vnt_keys), vnt) + @test isequal(subset(vnt, VarName[]), VarNamedTuple()) end """ A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" @@ -470,6 +474,42 @@ Base.size(st::SizedThing) = st.size test_invariants(vnt2) end + @testset "subset" begin + vnt = VarNamedTuple() + vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + vnt = setindex!!(vnt, :1, @varname(d[1])) + vnt = setindex!!(vnt, :2, @varname(d[2])) + vnt = setindex!!(vnt, :3, @varname(d[3])) + vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) + test_invariants(vnt) + + @test subset(vnt, VarName[]) == VarNamedTuple() + @test subset(vnt, (@varname(z),)) == VarNamedTuple() + @test subset(vnt, (@varname(d[4]),)) == VarNamedTuple() + # TODO(mhauru) Not sure what to do about the below. AbstractPPL considers d[1,1] to + # subsume d[1], but that breaks my idea of how VNT subset should work. 
+ @test subset(vnt, (@varname(d[1, 1]),)) == VarNamedTuple() broken = true + @test subset(vnt, [@varname(a)]) == VarNamedTuple(; a=1.0) + @test subset(vnt, [@varname(b), @varname(d[1])]) == + VarNamedTuple((@varname(b) => [1, 2, 3], @varname(d[1]) => :1)) + @test subset(vnt, [@varname(d[2:3])]) == + VarNamedTuple((@varname(d[2]) => :2, @varname(d[3]) => :3)) + @test subset(vnt, [@varname(d)]) == VarNamedTuple(( + @varname(d[1]) => :1, @varname(d[2]) => :2, @varname(d[3]) => :3 + )) + @test subset(vnt, [@varname(c.x.y)]) == VarNamedTuple((@varname(c.x.y) => [10],)) + @test subset(vnt, [@varname(c)]) == VarNamedTuple((@varname(c.x.y) => [10],)) + @test subset(vnt, [@varname(e.f[3, 3].g.h[2, 4, 1].i)]) == + VarNamedTuple((@varname(e.f[3, 3].g.h[2, 4, 1].i) => 2.0,)) + @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:14])]) == + VarNamedTuple((@varname(p[2, 1][2:4, 5:5, 11:14]) => SizedThing((3, 1, 4)),)) + # Cutting the last range a bit short should mean that nothing is returned. + @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:13])]) == VarNamedTuple() + end + @testset "keys and values" begin vnt = VarNamedTuple() @test @inferred(keys(vnt)) == VarName[] From c18258cfbc717452ffdc550943c9f0c26c85a5be Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 15:45:18 +0000 Subject: [PATCH 104/148] Make _compose_no_identity handle typed_identity too --- src/utils.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ed9f3aa13..11261334c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -951,6 +951,8 @@ Return `typeof(x)` stripped of its type parameters. """ basetypeof(x::T) where {T} = Base.typename(T).wrapper +const MaybeTypedIdentity = Union{typeof(typed_identity),typeof(identity)} + # TODO(mhauru) Might add another specialisation to _compose_no_identity, where if # ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only # the latter one would be kept. @@ -963,6 +965,6 @@ This helps avoid trivial cases of `ComposedFunction` that would cause unnecessar conflicts. """ _compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity +_compose_no_identity(::MaybeTypedIdentity, g) = g +_compose_no_identity(f, ::MaybeTypedIdentity) = f +_compose_no_identity(::MaybeTypedIdentity, ::MaybeTypedIdentity) = typed_identity From b91e6ff2f31c0cef2f22e9ae7e283512ceabb20a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 16:00:40 +0000 Subject: [PATCH 105/148] Myriad improvements to VNTVarInfo, overhaul varinfo.jl tests to use VNTVarInfo only --- src/DynamicPPL.jl | 2 +- src/test_utils/varinfo.jl | 42 +---- src/threadsafe.jl | 7 +- src/vntvarinfo.jl | 195 +++++++++++++++------ test/varinfo.jl | 351 +++++++++++--------------------------- 5 files changed, 249 insertions(+), 348 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5b831e100..7c1f53081 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("varnamedtuple.jl") -using .VarNamedTuples: VarNamedTuple, map!!, apply!! +using .VarNamedTuples: VarNamedTuple, map_pairs!!, map_values!!, apply!! 
include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 79b92ce13..25f4fd04f 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -33,40 +33,14 @@ of the varinfo instances. function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) - # # VarInfo - # vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) - # vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) - # vi_typed_metadata = DynamicPPL.typed_varinfo(model) - # vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) - - # # SimpleVarInfo - # svi_typed = SimpleVarInfo(example_values) - # svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) - # svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - - # varinfos = map(( - # vi_untyped_metadata, - # vi_untyped_vnv, - # vi_typed_metadata, - # vi_typed_vnv, - # svi_typed, - # svi_untyped, - # svi_vnv, - # )) do vi - # # Set them all to the same values and evaluate logp. - # vi = update_values!!(vi, example_values, varnames) - # last(DynamicPPL.evaluate!!(model, vi)) - # end - # - varinfos = map((DynamicPPL.typed_varinfo(model),)) do vi - # Set them all to the same values and evaluate logp. - vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi)) + vi = DynamicPPL.VarInfo(model) + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi)) + + varinfos = if include_threadsafe + (vi, DynamicPPL.ThreadSafeVarInfo(deepcopy(vi))) + else + (vi,) end - - if include_threadsafe - varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) - end - return varinfos end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c7ab106a2..44b4da316 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -36,6 +36,9 @@ function getacc(vi::ThreadSafeVarInfo, accname::Val) return foldl(combine, other_accs; init=main_acc) end +function Base.copy(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(copy(vi.varinfo), deepcopy(vi.accs_by_thread)) +end hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) @@ -195,8 +198,8 @@ end getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) -function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) +function unflatten!!(vi::ThreadSafeVarInfo, x::AbstractVector) + return Accessors.@set vi.varinfo = unflatten!!(vi.varinfo, x) end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index ad698d169..1ae9bc9d8 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -31,6 +31,8 @@ setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs) transformation(::VNTVarInfo) = DynamicTransformation() +Base.copy(vi::VNTVarInfo) = VNTVarInfo(copy(vi.values), copy(getaccs(vi))) + Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn) Base.length(vi::VNTVarInfo) = length(vi.values) @@ -40,25 +42,37 @@ function Base.getindex(vi::VNTVarInfo, vn::VarName) return tv.transform(tv.val) end +function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) + val = getindex_internal(vi, vn) + return from_maybe_linked_internal(vi, vn, dist, val) +end + Base.isempty(vi::VNTVarInfo) = isempty(vi.values) +Base.empty(vi::VNTVarInfo) = 
VNTVarInfo(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi.accs)) -# TODO(mhauru) This should be called setindex_internal!!, but that's not the current -# convention. -function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) +function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(val, old_tv.linked, old_tv.transform) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end +BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val) + # TODO(mhauru) The arguments are in the wrong order, but this is the current convetion. function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity) + # TODO(mhauru) We should move away from having all values vectorised by default. + # That messes with our use of unflatten though, so will require some thought. + transform = _compose_no_identity(transform, from_vec_transform(val)) + val = to_vec_transform(val)(val) new_tv = TransformedValue(val, false, transform) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end Base.keys(vi::VNTVarInfo) = keys(vi.values) +Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) @@ -68,7 +82,7 @@ function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) end function set_transformed!!(vi::VNTVarInfo, linked::Bool) - new_values = map!!(vi.values) do tv + new_values = map_values!!(vi.values) do tv TransformedValue(tv.val, linked, tv.transform) end return VNTVarInfo(new_values, vi.accs) @@ -79,6 +93,8 @@ function getindex_internal(vi::VNTVarInfo, vn::VarName) return tv.val end +# TODO(mhauru) This is mimicing old behaviour, but is now wrong: The internal +# representation does not have to be a Vector. getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector) function is_transformed(vi::VNTVarInfo, vn::VarName) @@ -86,15 +102,16 @@ function is_transformed(vi::VNTVarInfo, vn::VarName) return tv.linked end -# TODO(mhauru) Other VarInfos have something like this. Do we need it? -# function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) -# return from_vec_transform(dist) -# end - -function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) - return getindex(vi.values, vn).transform +# TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the +# below version? 
+function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) + return from_vec_transform(dist) end +# function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) +# return getindex(vi.values, vn).transform +# end + function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end @@ -113,14 +130,17 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = vi.values - for vn in vns - new_values = apply!!(new_values, vn) do tv - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + if !any(x -> subsumes(x, vn), vns) + # Not one of the target variables. + return tv end + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -135,15 +155,13 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = vi.values - vns = keys(vi) - for vn in vns - new_values = apply!!(new_values, vn) do tv - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv - end + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -155,13 +173,17 @@ end function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac = zero(LogProbType) new_values = vi.values - for vn in vns - new_values = apply!!(new_values, vn) do tv - transform = typed_identity - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + if !any(x -> subsumes(x, vn), vns) + # Not one of the target variables. + return tv end + current_val = tv.transform(tv.val) + transform = from_vec_transform(current_val) + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -175,14 +197,12 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) # map!!, but it doesn't have access to the VarName. 
cumulative_logjac = zero(LogProbType) new_values = vi.values - vns = keys(vi) - for vn in vns - new_values = apply!!(new_values, vn) do tv - transform = typed_identity - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv - end + new_values = map_values!!(new_values) do tv + current_val = tv.transform(tv.val) + transform = from_vec_transform(current_val) + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -191,10 +211,54 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) return vi end +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +end + +function link!!( + t::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vns::VarNameTuple, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) +end + +function invlink!!( + t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) +end + +function invlink!!( + ::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vns::VarNameTuple, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) +end + # TODO(mhauru) I don't think this should return the internal values, but that's the current # convention. function values_as(vi::VNTVarInfo, ::Type{Vector}) - return mapreduce(tv -> tovec(tv.val), vcat, vi.values; init=Union{}[]) + return mapfoldl(pair -> tovec(pair.second.val), vcat, vi.values; init=Union{}[]) +end + +function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} + return mapfoldl(identity, function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, vn) + end, vi.values; init=T()) end # TODO(mhauru) These two are now redundant, just conforming to the old interface @@ -225,22 +289,41 @@ function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitF return untyped_varinfo(Random.default_rng(), model, init_strategy) end -function unflatten(vi::VNTVarInfo, vec::AbstractVector) - index = 1 - new_values = map!!(vi.values) do tv - # TODO(mhauru) This is quite crude, assuming that the value stored currently is - # an AbstractArray of some kind that has a size, and that reshape makes sense here. - # I may fix this later, but I'm also tempted to just get rid of unflatten entirely. - # This works for now for making most tests pass. +""" + VectorChunkIterator{T<:AbstractVector} + +A tiny struct for getting chunks of a vector sequentially. + +The only function provided is `get_next_chunk!`, which takes a length and returns +a view into the next chunk of that length, updating the internal index. 
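+
+For example (a minimal sketch):
+
+    vci = VectorChunkIterator(collect(1:6), 1)
+    get_next_chunk!(vci, 2)  # view of elements 1:2; internal index moves to 3
+    get_next_chunk!(vci, 3)  # view of elements 3:5; internal index moves to 6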
+""" +mutable struct VectorChunkIterator{T<:AbstractVector} + vec::T + index::Int +end + +function get_next_chunk!(vci::VectorChunkIterator, len::Int) + i = vci.index + chunk = @view vci.vec[i:(i + len - 1)] + vci.index += len + return chunk +end + +function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) + # You may wonder, why have a whole struct for this, rather than just an index variable + # that the mapping function would close over. I wonder too. But for some reason type + # inference fails on such an index variable, turning it into a Core.Box. + vci = VectorChunkIterator(vec, 1) + new_values = map_values!!(vi.values) do tv old_val = tv.val - len = length(old_val) - new_val = reshape(vec[index:(index + len - 1)], size(old_val)) - # If the old_val was a scalar then new_val is a 0-dimensional array. - # Convert it to a scalar. - if !(old_val isa AbstractArray) && length(old_val) == 1 - new_val = new_val[1] + if !(old_val isa AbstractVector) + error( + "Can not unflatten a VarInfo for which existing values are not vectors:" * + " Got value of type $(typeof(old_val)).", + ) end - index += len + len = length(old_val) + new_val = get_next_chunk!(vci, len) return TransformedValue(new_val, tv.linked, tv.transform) end return VNTVarInfo(new_values, vi.accs) diff --git a/test/varinfo.jl b/test/varinfo.jl index a7948cc32..0a8e58eef 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -14,124 +14,59 @@ function check_varinfo_keys(varinfo, vns) end end -""" -Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. -""" -function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist) - r - else - vi[vn] - end -end - @testset "varinfo.jl" begin - @testset "VarInfo with NT of Metadata" begin - @model gdemo(x, y) = begin - s ~ InverseGamma(2, 3) - m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - end - model = gdemo(1.0, 2.0) - - _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) - tvi = DynamicPPL.typed_varinfo(vi) - - meta = vi.metadata - for f in fieldnames(typeof(tvi.metadata)) - fmeta = getfield(tvi.metadata, f) - for vn in fmeta.vns - @test tvi[vn] == vi[vn] - ind = meta.idcs[vn] - tind = fmeta.idcs[vn] - @test meta.dists[ind] == fmeta.dists[tind] - @test meta.is_transformed[ind] == fmeta.is_transformed[tind] - range = meta.ranges[ind] - trange = fmeta.ranges[tind] - @test all(meta.vals[range] .== fmeta.vals[trange]) - end - end - end - @testset "Base" begin # Test Base functions: # in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! - function test_base(vi_original) - vi = deepcopy(vi_original) - @test getlogjoint(vi) == 0 - @test isempty(vi[:]) - - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - - @test isempty(vi) - @test !haskey(vi, vn) - @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist) - @test !isempty(vi) - @test haskey(vi, vn) - @test vn in keys(vi) - - @test length(vi[vn]) == 1 - @test vi[vn] == r - @test vi[:] == [r] - vi = DynamicPPL.setindex!!(vi, 2 * r, vn) - @test vi[vn] == 2 * r - @test vi[:] == [2 * r] - - # TODO(mhauru) Implement these functions for other VarInfo types too. 
- if vi isa DynamicPPL.UntypedVectorVarInfo - delete!(vi, vn) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - end - - vi = empty!!(vi) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - @test !isempty(vi) - end - - test_base(VarInfo()) - test_base(DynamicPPL.typed_varinfo(VarInfo())) - test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(OrderedDict{VarName,Any}())) - test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) + vi = VarInfo() + @test getlogjoint(vi) == 0 + @test isempty(vi[:]) + + vn = @varname x + r = rand() + + @test isempty(vi) + @test !haskey(vi, vn) + @test !(vn in keys(vi)) + vi = push!!(vi, vn, r) + @test !isempty(vi) + @test haskey(vi, vn) + @test vn in keys(vi) + + @test length(vi[vn]) == 1 + @test vi[vn] == r + @test vi[:] == [r] + vi = DynamicPPL.setindex!!(vi, 2 * r, vn) + @test vi[vn] == 2 * r + @test vi[:] == [2 * r] + + vi = empty!!(vi) + @test isempty(vi) + vi = push!!(vi, vn, r) + @test !isempty(vi) end @testset "get/set/acclogp" begin - function test_varinfo_logp!(vi) - @test DynamicPPL.getlogjoint(vi) === 0.0 - vi = DynamicPPL.setlogprior!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 1.0 - @test DynamicPPL.getloglikelihood(vi) === 0.0 - @test DynamicPPL.getlogjoint(vi) === 1.0 - vi = DynamicPPL.acclogprior!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 0.0 - @test DynamicPPL.getlogjoint(vi) === 2.0 - vi = DynamicPPL.setloglikelihood!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 1.0 - @test DynamicPPL.getlogjoint(vi) === 3.0 - vi = DynamicPPL.accloglikelihood!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 2.0 - @test DynamicPPL.getlogjoint(vi) === 4.0 - end - vi = VarInfo() - test_varinfo_logp!(vi) - test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) - test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(OrderedDict())) - test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 end @testset "logp accumulators" begin @@ -150,7 +85,7 @@ end lp_d = logpdf(Normal(), values.d) m = demo() | (; c=values.c, d=values.d) - vi = DynamicPPL.unflatten(VarInfo(m), collect(values)) + vi = DynamicPPL.unflatten!!(VarInfo(m), collect(values)) vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b @@ -284,39 +219,23 @@ end end @testset "is_transformed flag" begin - # Test is_transformed and set_transformed!! 
- function test_varinfo!(vi) - vn_x = @varname x - dist = Normal(0, 1) - r = rand(dist) - - push!!(vi, vn_x, r, dist) + vi = VarInfo() + vn_x = @varname x + r = rand() - # is_transformed is set by default - @test !is_transformed(vi, vn_x) + vi = push!!(vi, vn_x, r) - vi = set_transformed!!(vi, true, vn_x) - @test is_transformed(vi, vn_x) + # is_transformed is unset by default + @test !is_transformed(vi, vn_x) - vi = set_transformed!!(vi, false, vn_x) - @test !is_transformed(vi, vn_x) - end - vi = VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) - end + vi = set_transformed!!(vi, true, vn_x) + @test is_transformed(vi, vn_x) - @testset "push!! to VarInfo with NT of Metadata" begin - vn_x = @varname x - vn_y = @varname y - untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = DynamicPPL.typed_varinfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) - @test typed_vi[vn_x] == 1.0 - @test typed_vi[vn_y] == 2.0 + vi = set_transformed!!(vi, false, vn_x) + @test !is_transformed(vi, vn_x) end + # TODO(mhauru) Move this to a different file. @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS chain = make_chain_from_prior(model, 10) @@ -354,39 +273,23 @@ end # change the VarInfo object. # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. vi = VarInfo() - meta = vi.metadata _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) - @test all(x -> !is_transformed(vi, x), meta.vns) + vals = values(vi) + + all_transformed(vi) = mapreduce(p -> p.second.linked, &, vi.values; init=true) + any_transformed(vi) = mapreduce(p -> p.second.linked, |, vi.values; init=false) + + @test !any_transformed(vi) # Check that linking and invlinking set the `is_transformed` flag accordingly - v = copy(meta.vals) vi = link!!(vi, model) - @test all(x -> is_transformed(vi, x), meta.vns) + @test all_transformed(vi) vi = invlink!!(vi, model) - @test all(x -> !is_transformed(vi, x), meta.vns) - @test meta.vals ≈ v atol = 1e-10 - - # Check that linking and invlinking preserves the values - vi = DynamicPPL.typed_varinfo(vi) - meta = vi.metadata - v_s = copy(meta.s.vals) - v_m = copy(meta.m.vals) - v_x = copy(meta.x.vals) - v_y = copy(meta.y.vals) - - @test all(x -> !is_transformed(vi, x), meta.s.vns) - @test all(x -> !is_transformed(vi, x), meta.m.vns) - vi = link!!(vi, model) - @test all(x -> is_transformed(vi, x), meta.s.vns) - @test all(x -> is_transformed(vi, x), meta.m.vns) - vi = invlink!!(vi, model) - @test all(x -> !is_transformed(vi, x), meta.s.vns) - @test all(x -> !is_transformed(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 + @test !any_transformed(vi) + @test values(vi) ≈ vals atol = 1e-10 # Transform only one variable - all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + all_vns = keys(vi) for vn in [ @varname(s), @varname(m), @@ -400,14 +303,11 @@ end @test !isempty(target_vns) @test !isempty(other_vns) vi = link!!(vi, (vn,), model) - @test all(x -> is_transformed(vi, x), target_vns) - @test all(x -> !is_transformed(vi, x), other_vns) + @test all_transformed(subset(vi, target_vns)) + @test !any_transformed(subset(vi, other_vns)) vi = invlink!!(vi, (vn,), model) - @test all(x -> !is_transformed(vi, x), all_vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - @test meta.x.vals ≈ v_x atol = 1e-10 - @test meta.y.vals ≈ v_y atol = 1e-10 + @test 
!any_transformed(vi) + @test values(vi) ≈ vals atol = 1e-10 end end @@ -417,46 +317,17 @@ end vn = @varname(x) dist = truncated(Normal(); lower=0) - function test_linked_varinfo(model, vi) - # vn and dist are taken from the containing scope - vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test is_transformed(vi, vn) - @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - @test getloglikelihood(vi) == 0.0 - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) - @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) - end - - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. - - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) - vi = DynamicPPL.set_transformed!!(vi, true, vn) - test_linked_varinfo(model, vi) - - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) + vi = DynamicPPL.VarInfo(model) vi = DynamicPPL.set_transformed!!(vi, true, vn) - test_linked_varinfo(model, vi) - - ### `SimpleVarInfo` - ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - test_linked_varinfo(model, vi) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test is_transformed(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end @testset "values_as" begin @@ -471,32 +342,16 @@ end @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. DynamicPPL.TestUtils.test_values(vi, example_values, vns) - - @testset "NamedTuple" begin - vals = values_as(vi, NamedTuple) - for vn in vns - if haskey(vals, Symbol(vn)) - # Assumed to be of form `(var"m[1]" = 1.0, ...)`. - @test getindex(vals, Symbol(vn)) == getindex(vi, vn) - else - # Assumed to be of form `(m = [1.0, ...], ...)`. - @test get(vals, vn) == getindex(vi, vn) - end - end + vals = values_as(vi, OrderedDict) + # All varnames in `vns` should be subsumed by one of `keys(vals)`. + @test all(vns) do vn + any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) end - - @testset "OrderedDict" begin - vals = values_as(vi, OrderedDict) - # All varnames in `vns` should be subsumed by one of `keys(vals)`. - @test all(vns) do vn - any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) - end - # Iterate over `keys(vals)` because we might have scenarios such as - # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is - # the varname present in `vns`, not `@varname(m)`. 
- for vn in keys(vals) - @test getindex(vals, vn) == getindex(vi, vn) - end + # Iterate over `keys(vals)` because we might have scenarios such as + # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is + # the varname present in `vns`, not `@varname(m)`. + for vn in keys(vals) + @test getindex(vals, vn) == getindex(vi, vn) end end end @@ -546,8 +401,8 @@ end @test DynamicPPL.is_transformed(varinfo_linked, vn) end @test length(varinfo[:]) > length(varinfo_linked[:]) - varinfo_linked_unflattened = DynamicPPL.unflatten( - varinfo_linked, varinfo_linked[:] + varinfo_linked_unflattened = DynamicPPL.unflatten!!( + copy(varinfo_linked), varinfo_linked[:] ) @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) @@ -591,13 +446,7 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the inconcrete `SimpleVarInfo` types, since checking for type - # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} - continue - end - @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + @inferred DynamicPPL.unflatten!!(varinfo, varinfo[:]) end end @@ -718,15 +567,6 @@ end @test varinfo_subset[:] == ground_truth end end - - # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end end @testset "merge" begin @@ -817,9 +657,9 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_single = push!!(vi_single, vn, 1.0) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + vi_double = push!!(vi_double, vn, [0.5, 0.6]) @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end @@ -830,8 +670,9 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(true, n))) isa + typeof(float(1)) # `Int`. - @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(1, n))) isa typeof(float(1)) end end From 8018f451a6207a4908f1886ab928c25751dd712b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 17:09:22 +0000 Subject: [PATCH 106/148] Fix a couple of ArrayLikeBlock bugs --- src/varnamedtuple.jl | 11 +++++++---- test/varnamedtuple.jl | 12 +++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 5d83afc5e..a0640fb3b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -525,6 +525,8 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + # The original, non-bare inds is needed later for ArrayLikeBlock checks. + orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if !(checkbounds(Bool, pa.mask, inds...) 
&& all(@inbounds(getindex(pa.mask, inds...)))) @@ -561,7 +563,7 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) if !(first_elem isa ArrayLikeBlock) throw(err) end - if inds != first_elem.inds + if orig_inds != first_elem.inds # The requested indices do not match the ones used to set the value. throw(err) end @@ -655,6 +657,7 @@ function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) pa = if checkbounds(Bool, pa.mask, inds...) @@ -679,7 +682,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) # some notion of size, and that size matches the indices that are being set. In this # case we wrap the value in an ArrayLikeBlock, and set all the individual indices # to point to that. - alb = ArrayLikeBlock(value, inds) + alb = ArrayLikeBlock(value, orig_inds) new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) else new_data = setindex!!(new_data, value, inds...) @@ -1180,11 +1183,11 @@ end function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) - if size(new_block) != size(alb.block) + if vnt_size(new_block) != vnt_size(alb.block) throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * - "from $(size(alb.block)) to $(size(new_block)).", + "from $(vnt_size(alb.block)) to $(vnt_size(new_block)).", ), ) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 655b8e9e5..18737f3d7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -6,7 +6,7 @@ using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple, subset using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! -using AbstractPPL: VarName, concretize, prefix +using AbstractPPL: AbstractPPL, VarName, concretize, prefix using BangBang: setindex!!, empty!! """ @@ -305,6 +305,16 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, vn)) == x test_invariants(vnt) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), vn)) + @test haskey(vnt, vn) + @test vn in keys(vnt) + @test @inferred(getindex(vnt, vn)) == SizedThing((3,)) + # TODO(mhauru) The below test_invariants fails because AbstractPPL's ConretizedSlice + # objects don't respect the eval(Meta.parse(repr(...))) == ... property. + # test_invariants(vnt) + + vnt = VarNamedTuple() y = fill("a", (3, 2, 4)) x = y[:, 2, :] a = (; b=[nothing, nothing, (; c=(; d=reshape(y, (1, 3, 2, 4, 1))))]) From 1cbcda7b551528feaf24e5041614732edc2038d0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:02:15 +0000 Subject: [PATCH 107/148] Fix PartialArray map bug --- src/varnamedtuple.jl | 19 ++++++++- test/varnamedtuple.jl | 90 ++++++++++++++++++++++++++++++++----------- 2 files changed, 84 insertions(+), 25 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index a0640fb3b..4c739e681 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1169,12 +1169,27 @@ function _map_recursive!!(func, pa::PartialArray, vn) # We need to allocate a new data array. similar(pa.data, new_et) end + # Keep a dictionary of already-seen ArrayLikeBlocks to avoid redundant computations. + # This matters not only for performance, but also for correctness, because + # _map_recursive!! 
may mutate the value, and we don't want to mutate it multiple times.
+    albs_seen = Dict{ArrayLikeBlock,ArrayLikeBlock}()
     @inbounds for i in CartesianIndices(pa.mask)
         if pa.mask[i]
             val = pa.data[i]
-            ind = val isa ArrayLikeBlock ? val.inds : Tuple(i)
+            is_alb = val isa ArrayLikeBlock
+            if is_alb
+                if val in keys(albs_seen)
+                    new_data[i] = albs_seen[val]
+                    continue
+                end
+            end
+            ind = is_alb ? val.inds : Tuple(i)
             new_vn = IndexLens(ind) ∘ vn
-            new_data[i] = _map_recursive!!(func, pa.data[i], new_vn)
+            new_val = _map_recursive!!(func, pa.data[i], new_vn)
+            new_data[i] = new_val
+            if is_alb
+                albs_seen[val] = new_val
+            end
         end
     end
     # The above type inference may be overly conservative, so we concretise the eltype.
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index 18737f3d7..a18885be7 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -10,13 +10,15 @@ using AbstractPPL: AbstractPPL, VarName, concretize, prefix
 using BangBang: setindex!!, empty!!
 
 """
-    test_invariants(vnt::VarNamedTuple)
+    test_invariants(vnt::VarNamedTuple; skip=())
 
 Test properties that should hold for all VarNamedTuples.
 
 Uses @test for all the tests. Intended to be called inside a @testset.
+
+`skip` is a tuple of symbols indicating which tests are to be skipped.
 """
-function test_invariants(vnt::VarNamedTuple)
+function test_invariants(vnt::VarNamedTuple; skip=())
     # These will be needed repeatedly.
     vnt_keys = keys(vnt)
     vnt_values = values(vnt)
@@ -41,12 +43,14 @@
     # The below eval test is a bit fragile: If any elements in vnt don't respect the same
     # reconstructability-from-repr property, this will fail. Likewise if any element uses
     # in its repr print out types that are not in scope in this module, it will fail.
-    vnt3 = eval(Meta.parse(repr(vnt)))
-    equality = (vnt == vnt3)
-    # The value may be `missing` if vnt itself has values that are missing.
-    @test equality === true || equality === missing
-    @test isequal(vnt, vnt3)
-    @test hash(vnt) == hash(vnt3)
+    if !(:parseeval in skip)
+        vnt3 = eval(Meta.parse(repr(vnt)))
+        equality = (vnt == vnt3)
+        # The value may be `missing` if vnt itself has values that are missing.
+        @test equality === true || equality === missing
+        @test isequal(vnt, vnt3)
+        @test hash(vnt) == hash(vnt3)
+    end
 
     # Check that merge with an empty VarNamedTuple is a no-op.
     @test isequal(merge(vnt, VarNamedTuple()), vnt)
@@ -310,9 +314,9 @@ Base.size(st::SizedThing) = st.size
         @test haskey(vnt, vn)
         @test vn in keys(vnt)
         @test @inferred(getindex(vnt, vn)) == SizedThing((3,))
-        # TODO(mhauru) The below test_invariants fails because AbstractPPL's ConretizedSlice
+        # TODO(mhauru) The below skip is needed because AbstractPPL's ConcretizedSlice
         # objects don't respect the eval(Meta.parse(repr(...))) == ... property.
-        # test_invariants(vnt)
+        test_invariants(vnt; skip=(:parseeval,))
 
         vnt = VarNamedTuple()
         y = fill("a", (3, 2, 4))
@@ -927,15 +931,21 @@
             vnt = @inferred(
                 setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4]))
             )
+            concretized_vn = concretize(@varname(v[:]), [0, 0])
+            vnt = @inferred(setindex!!(vnt, SizedThing((2,)), concretized_vn))
             vnt = @inferred(setindex!!(vnt, "", @varname(w[4][3][2, 1])))
-            test_invariants(vnt)
+            # TODO(mhauru) The below skip is needed because AbstractPPL's ConcretizedSlice
+            # objects don't respect the eval(Meta.parse(repr(...))) == ... property.
+ test_invariants(vnt; skip=(:parseeval,)) struct AnotherSizedThing{T<:Tuple} size::T end Base.size(st::AnotherSizedThing) = st.size + call_counter = 0 function f_val(val) + call_counter += 1 if val isa Int return val + 10 elseif val isa AbstractVector{Int} @@ -956,8 +966,9 @@ Base.size(st::SizedThing) = st.size f_pair(pair) = f_val(pair.second) val_reduction = mapreduce(pair -> pair.second, vcat, vnt; init=Any[]) - @test val_reduction == - vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") + @test val_reduction == vcat( + Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), SizedThing((2,)), "" + ) key_reduction = mapreduce(pair -> pair.first, vcat, vnt; init=Any[]) @test key_reduction == vcat( @varname(a), @@ -967,18 +978,35 @@ Base.size(st::SizedThing) = st.size @varname(e.f[3].g.h[2].i), @varname(e.f[3].g.h[2].j), @varname(y.z[3, 2:3, 3, 2:3, 4]), + concretized_vn, @varname(w[4][3][2, 1]), ) + + call_counter = 0 reduction = mapreduce(f_pair, vcat, vnt; init=Any[]) - @test reduction == - vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") + @test reduction == vcat( + Any[], + 11, + [12, 12], + [2.0], + "ab", + 6.0, + AnotherSizedThing((2, 2)), + AnotherSizedThing((2,)), + "b", + ) + # Check that f_pair gets called exactly once per element. + @test call_counter == length(keys(vnt)) # TODO(mhauru) This should hopefully be type stable, but fails to be so because of # some complex VarNames being too much for constant propagation. See comment in # src/varnamedtuple.jl for more. + call_counter = 0 vnt_mapped = map_pairs!!(f_pair, copy(vnt)) + # Check that f_pair gets called exactly once per element. + @test call_counter == length(keys(vnt)) @test vnt_mapped == map_values!!(f_val, copy(vnt)) - test_invariants(vnt_mapped) + test_invariants(vnt_mapped; skip=(:parseeval,)) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @test @inferred(getindex(vnt_mapped, @varname(c.d))) == [2.0] @@ -986,29 +1014,38 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + @test @inferred(getindex(vnt_mapped, concretized_vn)) == AnotherSizedThing((2,)) @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" + call_counter = 0 vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) - test_invariants(vnt_applied) + @test call_counter == 1 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(b[1:2]))) - test_invariants(vnt_applied) + # Unlike map_pairs!!, apply!! operates on the whole value at once, rather than + # element-wise, so this is only one more call. 
+ @test call_counter == 2 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(c.d))) - test_invariants(vnt_applied) + @test call_counter == 3 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) - test_invariants(vnt_applied) + @test call_counter == 4 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) - test_invariants(vnt_applied) + @test call_counter == 5 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 @@ -1016,12 +1053,19 @@ Base.size(st::SizedThing) = st.size # know at compile time that this sets the only one, thus allowing the element type # to be AnotherSizedThing. vnt_applied = apply!!(f_val, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) - test_invariants(vnt_applied) + @test call_counter == 6 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + vnt_applied = apply!!(f_val, vnt_applied, concretized_vn) + @test call_counter == 7 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, concretized_vn)) == AnotherSizedThing((2,)) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) - test_invariants(vnt_applied) + @test call_counter == 8 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" # map a function that maps every key => value pair to key => key. 
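A quick illustrative sketch (not itself one of the patches) of the contract that the `call_counter` assertions in the test diff above pin down. All names are taken from `test/varnamedtuple.jl` as it stands after PATCH 107; the API is still in flux, so treat this as an assumption-laden example rather than settled usage:

    using DynamicPPL: VarNamedTuple, @varname
    using DynamicPPL.VarNamedTuples: map_pairs!!, apply!!
    using BangBang: setindex!!

    vnt = VarNamedTuple()
    vnt = setindex!!(vnt, 1, @varname(a))
    # A 2-vector assigned to b[1:2] is stored as a single ArrayLikeBlock.
    vnt = setindex!!(vnt, [2, 2], @varname(b[1:2]))

    # map_pairs!! calls the function once per stored pair (twice here); with
    # the albs_seen dedup above, the b[1:2] block is visited once, not per index.
    vnt2 = map_pairs!!(pair -> pair.second .+ 1, copy(vnt))

    # apply!! calls the function exactly once, on the whole value at the key.
    vnt3 = apply!!(val -> val .+ 10, vnt, @varname(b[1:2]))
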
From 573cd5afea464c2148ee5767011d676e4ab15093 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Fri, 9 Jan 2026 18:04:36 +0000
Subject: [PATCH 108/148] In VNTVarInfo, handle variables with varying
 dimensions correctly

---
 src/contexts/init.jl |  2 +-
 src/threadsafe.jl    |  6 ++++++
 src/vntvarinfo.jl    | 26 ++++++++++++++++----------
 test/varinfo.jl      | 20 --------------------
 4 files changed, 23 insertions(+), 31 deletions(-)

diff --git a/src/contexts/init.jl b/src/contexts/init.jl
index f137e07d6..b118280d0 100644
--- a/src/contexts/init.jl
+++ b/src/contexts/init.jl
@@ -363,7 +363,7 @@ function tilde_assume!!(
             vi = setindex!!(vi, val_to_insert, vn)
         else
             vi = if vi isa VNTVarInfo
-                push!!(vi, vn, val_to_insert, inverse(transform))
+                push!!(vi, vn, val_to_insert, inverse(transform), size(x))
             else
                 push!!(vi, vn, val_to_insert, dist)
             end
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index 44b4da316..f168eb7c1 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -71,6 +71,12 @@ function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distributi
     return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
 end
 
+function BangBang.push!!(
+    vi::ThreadSafeVarInfo, vn::VarName, r, transform=typed_identity, orig_size=size(r)
+)
+    return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, transform, orig_size)
+end
+
 syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
 
 setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl
index 1ae9bc9d8..756bf8e34 100644
--- a/src/vntvarinfo.jl
+++ b/src/vntvarinfo.jl
@@ -6,12 +6,15 @@ end
 # TODO(mhauru) Make this renaming permanent.
 const VarInfo = VNTVarInfo
 
-struct TransformedValue{ValType,TransformType}
+struct TransformedValue{ValType,TransformType,SizeType}
     val::ValType
     linked::Bool
     transform::TransformType
+    size::SizeType
 end
 
+VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size
+
 VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators())
 
 function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
@@ -53,7 +56,7 @@ BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi.
 
 function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName)
     old_tv = getindex(vi.values, vn)
-    new_tv = TransformedValue(val, old_tv.linked, old_tv.transform)
+    new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size)
     new_values = setindex!!(vi.values, new_tv, vn)
     return VNTVarInfo(new_values, vi.accs)
 end
@@ -61,12 +64,14 @@ end
 
 BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val)
 
 # TODO(mhauru) The arguments are in the wrong order, but this is the current convention.
-function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity)
+function BangBang.push!!(
+    vi::VNTVarInfo, vn::VarName, val, transform=typed_identity, orig_size=size(val)
+)
     # TODO(mhauru) We should move away from having all values vectorised by default.
transform = _compose_no_identity(transform, from_vec_transform(val)) val = to_vec_transform(val)(val) - new_tv = TransformedValue(val, false, transform) + new_tv = TransformedValue(val, false, transform, orig_size) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end @@ -76,14 +81,14 @@ Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; ini function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(old_tv.val, linked, old_tv.transform) + new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end function set_transformed!!(vi::VNTVarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv - TransformedValue(tv.val, linked, tv.transform) + TransformedValue(tv.val, linked, tv.transform, tv.size) end return VNTVarInfo(new_values, vi.accs) end @@ -121,9 +126,11 @@ function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) end function change_transform(tv::TransformedValue, new_transform, linked) + # Note that the transform may change the size of `val`, but it doesn't change the + # tv.size, since that one tracks the original size of the value before any transforms. val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) - return TransformedValue(val_new, linked, new_transform), logjac1 + logjac2 + return TransformedValue(val_new, linked, new_transform, tv.size), logjac1 + logjac2 end function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) @@ -154,8 +161,7 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) # map!!, but it doesn't have access to the VarName. dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair dist = getindex(dists, vn) transform = from_linked_vec_transform(dist) @@ -324,7 +330,7 @@ function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) end len = length(old_val) new_val = get_next_chunk!(vci, len) - return TransformedValue(new_val, tv.linked, tv.transform) + return TransformedValue(new_val, tv.linked, tv.transform, tv.size) end return VNTVarInfo(new_values, vi.accs) end diff --git a/test/varinfo.jl b/test/varinfo.jl index 0a8e58eef..0bea67402 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -369,26 +369,6 @@ end model, value_true, varnames; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - - if DynamicPPL.has_varnamedvector(varinfo) && mutating - # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. 
- @test_broken false - continue - end - # Evaluate the model once to update the logp of the varinfo. varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) From c353cbc7ba9fa117bfd330f3d806831ce92be176 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:25:31 +0000 Subject: [PATCH 109/148] Fix two small bugs --- src/utils.jl | 4 ++++ src/varnamedtuple.jl | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 11261334c..0e03c5cdc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,7 @@ +# subset is defined here to avoid circular dependencies between files. Methods for it are +# defined in other files. +function subset end + # singleton for indicating if no default arguments are present struct NoDefault end const NO_DEFAULT = NoDefault() diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 4c739e681..2f7e38ca6 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -388,6 +388,13 @@ function BangBang.empty!!(pa::PartialArray) return pa end +# This is a tad hacky: We use _mapreduce_recursive which requires a prefix VarName. We give +# it the non-sense @varname(_), and then strip it away with the mapping function, returning +# only the optic. +function Base.keys(pa::PartialArray) + return _mapreduce_recursive(pair -> first(pair).optic, push!, pa, @varname(_), Any[]) +end + # Length could be defined as a special case of mapreduce, but it's harder to keep it type # stable that way: If the element type is abstract, we end up calling _mapreduce_recursive # on an abstract type, which makes the type of the cumulant Any. @@ -1500,8 +1507,8 @@ function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYP # Note that _getindex, rather than getindex, skips the need to denseify PartialArrays. val = _getindex(vnt, vn) if !(val isa VarNamedTuple || val isa PartialArray) - # There is _a_ value. Where it's the right kind, we do not know, but returning true - # is no worse than `hasvalue` returning true for e.g. UnivariateDistributions + # There is _a_ value. Whether it's the right kind, we do not know, but returning + # true is no worse than `hasvalue` returning true for e.g. UnivariateDistributions # whenever there is at least some value. return true end From a36bb150d2c44ab365514650be00d4ba005b9280 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:59:29 +0000 Subject: [PATCH 110/148] Allow nested PartialArrays with ArrayLikeBlocks --- src/varnamedtuple.jl | 15 +++++++++++++-- test/varnamedtuple.jl | 2 ++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 2f7e38ca6..37158442b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -110,6 +110,8 @@ Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`. By default, this falls back onto `Base.size`, but can be overloaded for custom types. This notion of type is used to determine whether a value can be set into a `PartialArray` as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. + +A special return value of `Val(:pass)` indicates that the size check should be skipped. """ vnt_size(x) = size(x) @@ -294,6 +296,13 @@ end # The size of the .data field is an implementation detail. _internal_size(pa::PartialArray, args...) = size(pa.data, args...) +# Even though a PartialArray has no well-defined size, we still allow it to be used as an +# ArrayLikeBlock. 
This enables setting values for keys like @varname(x[1:3][1]), which will +# be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. +# Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a +# valid key. +vnt_size(pa::PartialArray) = Val(:pass) + function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively # copy. @@ -677,7 +686,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if vnt_size(value) != inds_size + if vnt_size(value) !== Val(:pass) && vnt_size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(vnt_size(value)), which does not match " * @@ -1205,7 +1214,9 @@ end function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) - if vnt_size(new_block) != vnt_size(alb.block) + sz_new = vnt_size(new_block) + sz_old = vnt_size(alb.block) + if sz_new !== Val(:pass) && sz_old !== Val(:pass) && sz_new != sz_old throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index a18885be7..1937ea189 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -333,6 +333,8 @@ Base.size(st::SizedThing) = st.size vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1, @varname(a[1][1]))) @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, 1, @varname(ab[1:2][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 vnt = @inferred(setindex!!(vnt, [1], @varname(b[1].c[1]))) @test @inferred(getindex(vnt, @varname(b[1].c[1]))) == [1] vnt = @inferred(setindex!!(vnt, [1], @varname(e[3, 2].f[2, 2][10, 10]))) From bf05554ff96d2e7e72b8a64c497c296af9f24851 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:59:59 +0000 Subject: [PATCH 111/148] Stop testing for NamedDist with unconcrete VarName --- test/compiler.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 5101bd602..e4a9a2474 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -341,21 +341,21 @@ module Issue537 end end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) - @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) + @model f2() = x ~ NamedDist(Normal(), @varname(y[2][5, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test first(Base.keys(vi1.metadata.y)) == @varname(y) - @test haskey(vi2.metadata, :y) - @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) - @test haskey(vi3.metadata, :y) - @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) + @test haskey(vi1, @varname(y)) + @test first(Base.keys(vi1)) == @varname(y) + @test haskey(vi2, @varname(y[2][5, 1])) + @test first(Base.keys(vi2)) == @varname(y[2][5, 1]) + @test haskey(vi3, @varname(y[1])) + @test first(Base.keys(vi3)) == @varname(y[1]) # Conditioning f1_c = f1() | (y=1,) - f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 1])) => 1,)) + f2_c = f2() | NamedTuple((Symbol(@varname(y[2][5, 1])) => 1,)) f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,)) @test f1_c() == 1 # TODO(torfjelde): We need conditioning for `Dict`. 
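Another hedged sketch, this time of the nested-key behaviour that `vnt_size(pa::PartialArray) = Val(:pass)` in PATCH 110 enables. The key below mirrors the new `@varname(ab[1:2][1])` test case; the same caveats apply as in the earlier sketch, since this API is still under development:

    using DynamicPPL: VarNamedTuple, @varname
    using BangBang: setindex!!

    vnt = VarNamedTuple()
    # Only one element of the x[1:3] block is set, so the block's value is a
    # PartialArray wrapped in an ArrayLikeBlock; Val(:pass) skips its size check.
    vnt = setindex!!(vnt, 1.0, @varname(x[1:3][1]))
    getindex(vnt, @varname(x[1:3][1]))  # == 1.0
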
From 7857eaee73c3854bde5ae30f1626f70c91691432 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 19:39:42 +0000 Subject: [PATCH 112/148] Misc bugfixes --- docs/src/api.md | 2 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 2 +- src/DynamicPPL.jl | 4 +- src/abstract_varinfo.jl | 10 ++--- src/contexts/init.jl | 3 +- src/logdensityfunction.jl | 23 ++++++++---- src/test_utils/ad.jl | 2 +- src/values_as_in_model.jl | 2 +- test/logdensityfunction.jl | 45 ++++++++++------------- test/model.jl | 47 +++++++----------------- 10 files changed, 61 insertions(+), 79 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f687fd90a..084639c07 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -468,7 +468,7 @@ DynamicPPL.maybe_invlink_before_eval!! ```@docs Base.merge(::AbstractVarInfo) DynamicPPL.subset -DynamicPPL.unflatten +DynamicPPL.unflatten!! ``` ### Evaluation Contexts diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index e28560872..8e53d8709 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -212,7 +212,7 @@ function DynamicPPL.VarInfo( if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end - return DynamicPPL.unflatten(original_vi, full_params) + return DynamicPPL.unflatten!!(original_vi, full_params) end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6248a4d0..84b8b2e68 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,14 +196,14 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -include("varnamedvector.jl") +# include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") # include("varinfo.jl") include("vntvarinfo.jl") -include("simple_varinfo.jl") +# include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 898b6caf9..ef1d92042 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -838,14 +838,14 @@ function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user + # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user # interface, and it's also dangerous for any AbstractVarInfo types that may not respect # a particular ordering, such as SimpleVarInfo{Dict}. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) # Set parameters and add the logjac term. - vi = unflatten(vi, y) + vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) end @@ -910,7 +910,7 @@ function invlink!!( # Mildly confusing: we need to _add_ the logjac of the inverse transform, # because we are trying to remove the logjac of the forward transform # that was previously accumulated when linking. - vi = unflatten(vi, x) + vi = unflatten!!(vi, x) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, inv_logjac) end @@ -1013,11 +1013,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo, x::AbstractVector) + unflatten!!(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. 
""" -function unflatten end +function unflatten!! end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index b118280d0..8899ba4d0 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -363,7 +363,8 @@ function tilde_assume!!( vi = setindex!!(vi, val_to_insert, vn) else vi = if vi isa VNTVarInfo - push!!(vi, vn, val_to_insert, inverse(transform), size(x)) + x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () + vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) else push!!(vi, vn, val_to_insert, dist) end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index adcb319c8..e7c83ecb4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -311,13 +311,22 @@ This function returns a VarNamedTuple mapping all VarNames to their correspondin `RangeAndLinked`. """ function get_ranges_and_linked(vi::VNTVarInfo) - offset = 1 - vnt = map!!(vi.values) do tv - val = tv.val - range = offset:(offset + length(val) - 1) - offset += length(val) - RangeAndLinked(range, tv.linked, size(val)) - end + # TODO(mhauru) Check that the closure doesn't cause type instability here. + vnt = VarNamedTuple() + vnt, _ = mapreduce( + identity, + function ((vnt, offset), pair) + vn, tv = pair + val = tv.val + range = offset:(offset + length(val) - 1) + offset += length(val) + ral = RangeAndLinked(range, tv.linked, size(val)) + vnt = setindex!!(vnt, ral, vn) + return vnt, offset + end, + vi.values; + init=(VarNamedTuple(), 1), + ) return vnt end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a030b479e..6bcd9547e 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -242,7 +242,7 @@ Everything else is optional, and can be categorised into several groups: Finally, note that these only reflect the parameters used for _evaluating_ the gradient. If you also want to control the parameters used for _preparing_ the gradient, then you need to manually set these parameters in - the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi, + the VarInfo object, for example using `vi = DynamicPPL.unflatten!!(vi, prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 992cbdc8d..f7440d6ff 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -118,7 +118,7 @@ julia> # Perform computations in unconstrained space, e.g. changing the values o θ = [!varinfo[@varname(x)], rand(rng)]; julia> # Update the `VarInfo` with the new values. - varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, θ); julia> # Determine the expected support of `y`. lb, ub = θ[1] == 1 ? 
(0, 1) : (11, 12) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 77ae0ccab..777b91ee4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -17,31 +17,24 @@ using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - @testset "$varinfo_func" for varinfo_func in [ - # DynamicPPL.untyped_varinfo, - DynamicPPL.typed_varinfo, - # DynamicPPL.untyped_vector_varinfo, - # DynamicPPL.typed_vector_varinfo, - ] - unlinked_vi = varinfo_func(m) - @testset "$islinked" for islinked in (false, true) - vi = if islinked - DynamicPPL.link!!(unlinked_vi, m) - else - unlinked_vi - end - ranges = DynamicPPL.get_ranges_and_linked(vi) - params = [x for x in vi[:]] - # Iterate over all variables - for vn in keys(vi) - # Check that `getindex_internal` returns the same thing as using the ranges - # directly - range_with_linked = ranges[vn] - @test params[range_with_linked.range] == - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) - # Check that the link status is correct - @test range_with_linked.is_linked == islinked - end + @testset "$islinked" for islinked in (false, true) + unlinked_vi = DynamicPPL.VarInfo(m) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = ranges[vn] + @test params[range_with_linked.range] == + DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked end end end @@ -104,8 +97,8 @@ end @testset "LogDensityFunction: Type stability" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - unlinked_vi = DynamicPPL.VarInfo(m) @testset "$islinked" for islinked in (false, true) + unlinked_vi = DynamicPPL.VarInfo(m) vi = if islinked DynamicPPL.link!!(unlinked_vi, m) else diff --git a/test/model.jl b/test/model.jl index 29b9650a5..05688c224 100644 --- a/test/model.jl +++ b/test/model.jl @@ -26,8 +26,7 @@ function innermost_distribution_type(d::Distributions.Product) end is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true -is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(varinfo::DynamicPPL.VNTVarInfo) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -230,24 +229,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. 
r_raw = randn(length(vi[:])) * 10 - vi = DynamicPPL.unflatten(vi, r_raw) + vi = DynamicPPL.unflatten!!(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) end end - @testset "Dynamic constraints, VectorVarInfo" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - for i in 1:10 - for vi_constructor in - [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] - vi = vi_constructor(model) - @test vi[@varname(x)] >= vi[@varname(m)] - end - end - end - @testset "rand" begin model = GDEMO_DEFAULT @@ -510,26 +498,17 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end model = product_dirichlet() - varinfos = [ - DynamicPPL.untyped_varinfo(model), - DynamicPPL.typed_varinfo(model), - DynamicPPL.typed_simple_varinfo(model), - DynamicPPL.untyped_simple_varinfo(model), - ] - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - logjoint = getlogjoint(varinfo) # unlinked space - varinfo_linked = DynamicPPL.link(varinfo, model) - varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) - ) - # getlogjoint should return the same result as before it was linked - @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) - @test getlogjoint(varinfo_linked) ≈ logjoint - # getlogjoint_internal shouldn't - @test getlogjoint_internal(varinfo_linked) ≈ - getlogjoint_internal(varinfo_linked_result) - @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) - end + varinfo = DynamicPPL.VarInfo(model) + logjoint = getlogjoint(varinfo) # unlinked space + varinfo_linked = DynamicPPL.link(varinfo, model) + varinfo_linked_result = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked))) + # getlogjoint should return the same result as before it was linked + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end @testset "predict" begin From 16fe15056559a3cc1a2e783727f1a76b18c9d350 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 19:41:21 +0000 Subject: [PATCH 113/148] Stop running SVI and VNT tests --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e0b42904c..e04b664fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,9 +54,9 @@ include("test_util.jl") include("accumulators.jl") include("compiler.jl") include("varnamedtuple.jl") - include("varnamedvector.jl") + # include("varnamedvector.jl") include("varinfo.jl") - include("simple_varinfo.jl") + # include("simple_varinfo.jl") include("model.jl") include("distribution_wrappers.jl") include("linking.jl") From 51a518f19a84dc60eb6976f0a13372ecbf835ab6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 09:43:41 +0000 Subject: [PATCH 114/148] Fix LDF bug --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e7c83ecb4..44fdad5a8 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -320,7 +320,7 @@ function get_ranges_and_linked(vi::VNTVarInfo) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv.linked, size(val)) + ral = RangeAndLinked(range, tv.linked, tv.size) 
vnt = setindex!!(vnt, ral, vn) return vnt, offset end, From 1950a9345e3c1740ecc17f27c94c5786f9c959e6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:41:17 +0000 Subject: [PATCH 115/148] Fix some bugs, simplify (inv)linking --- src/contexts/init.jl | 15 +++---- src/test_utils/varinfo.jl | 2 +- src/vntvarinfo.jl | 83 +++++++++++++-------------------------- 3 files changed, 33 insertions(+), 67 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 8899ba4d0..45d6356f1 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -308,7 +308,6 @@ end function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) - in_varinfo = haskey(vi, vn) val, transform = init(ctx.rng, vn, dist, ctx.strategy) x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. @@ -317,7 +316,7 @@ function tilde_assume!!( # check the rest of the VarInfo to see if other variables are linked. # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. - insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) + insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) val_to_insert, logjac = if insert_transformed_value # Calculate the forward logjac and sum them up. lt = link_transform(dist) @@ -359,15 +358,11 @@ function tilde_assume!!( end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. - if in_varinfo - vi = setindex!!(vi, val_to_insert, vn) + vi = if vi isa VNTVarInfo + x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () + vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) else - vi = if vi isa VNTVarInfo - x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () - vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) - else - push!!(vi, vn, val_to_insert, dist) - end + push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 25f4fd04f..bbfb0b662 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -35,7 +35,7 @@ function setup_varinfos( ) vi = DynamicPPL.VarInfo(model) vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi)) + vi = last(DynamicPPL.evaluate!!(model, vi)) varinfos = if include_threadsafe (vi, DynamicPPL.ThreadSafeVarInfo(deepcopy(vi))) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 756bf8e34..a0392334c 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -125,28 +125,24 @@ function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function change_transform(tv::TransformedValue, new_transform, linked) - # Note that the transform may change the size of `val`, but it doesn't change the - # tv.size, since that one tracks the original size of the value before any transforms. 
- val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) - val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) - return TransformedValue(val_new, linked, new_transform, tv.size), logjac1 + logjac2 -end - function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair - if !any(x -> subsumes(x, vn), vns) + if vns !== nothing && !any(x -> subsumes(x, vn), vns) # Not one of the target variables. return tv end dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac + vec_transform = from_vec_transform(dist) + link_transform = from_linked_vec_transform(dist) + val_untransformed, logjac1 = with_logabsdet_jacobian(vec_transform, tv.val) + val_new, logjac2 = with_logabsdet_jacobian( + inverse(link_transform), val_untransformed + ) + new_tv = TransformedValue(val_new, true, link_transform, tv.size) + cumulative_logjac += logjac1 + logjac2 return new_tv end vi = VNTVarInfo(new_values, vi.accs) @@ -156,39 +152,29 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) return vi end -function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) - # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use - # map!!, but it doesn't have access to the VarName. - dists = extract_priors(model, vi) - cumulative_logjac = zero(LogProbType) - new_values = map_pairs!!(vi.values) do pair - vn, tv = pair - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv - end - vi = VNTVarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi +function link!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) + return link!!(t, vi, nothing, model) end function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair - if !any(x -> subsumes(x, vn), vns) + if vns !== nothing && !any(x -> subsumes(x, vn), vns) # Not one of the target variables. return tv end - current_val = tv.transform(tv.val) - transform = from_vec_transform(current_val) - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac + current_val = tv.val + dist = getindex(dists, vn) + vec_transform = from_vec_transform(dist) + link_transform = from_linked_vec_transform(dist) + val_untransformed, logjac1 = with_logabsdet_jacobian(link_transform, current_val) + val_new, logjac2 = with_logabsdet_jacobian( + inverse(vec_transform), val_untransformed + ) + new_tv = TransformedValue(val_new, false, vec_transform, tv.size) + cumulative_logjac += logjac1 + logjac2 return new_tv end vi = VNTVarInfo(new_values, vi.accs) @@ -198,23 +184,8 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) return vi end -function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) - # TODO(mhauru) This is probably pretty inefficient. Do this better. 
Would like to use - # map!!, but it doesn't have access to the VarName. - cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_values!!(new_values) do tv - current_val = tv.transform(tv.val) - transform = from_vec_transform(current_val) - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv - end - vi = VNTVarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi +function invlink!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) + return invlink!!(t, vi, nothing, model) end function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) From 051521a83b96ed537b4272acea0b6dff9c8843af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:42:23 +0000 Subject: [PATCH 116/148] Fix some tests --- test/contexts.jl | 214 ++++++++++++++++++++------------------------ test/debug_utils.jl | 54 ----------- test/lkj.jl | 2 +- 3 files changed, 98 insertions(+), 172 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 9621013ac..24f6445f5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -414,18 +414,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - empty_varinfos = [ - ("untyped+metadata", VarInfo()), - ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), - ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), - ( - "typed+VNV", - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), - ), - ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())), - ] - @model function test_init_model() x ~ Normal() y ~ MvNormal(fill(x, 2), I) @@ -438,19 +426,17 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Check that init!! can generate values that weren't there # previously. model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - this_vi = deepcopy(empty_vi) - _, vi = DynamicPPL.init!!(model, this_vi, strategy) - @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) - x, y = vi[@varname(x)], vi[@varname(y)] - @test x isa Real - @test y isa AbstractVector{<:Real} - @test length(y) == 2 - (; logprior, loglikelihood) = getlogp(vi) - @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == - logprior - @test logpdf(Normal(), 1.0) == loglikelihood - end + empty_vi = VarInfo() + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == logprior + @test logpdf(Normal(), 1.0) == loglikelihood end end @@ -458,40 +444,38 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "replacing old values: $(typeof(strategy))" begin # Check that init!! can overwrite values that were already there. 
model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - # start by generating some rubbish values - vi = deepcopy(empty_vi) - old_x, old_y = 100000.00, [300000.00, 500000.00] - push!!(vi, @varname(x), old_x, Normal()) - push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) - # then overwrite it - _, new_vi = DynamicPPL.init!!(model, vi, strategy) - new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] - # check that the values are (presumably) different - @test old_x != new_x - @test old_y != new_y - end + empty_vi = VarInfo() + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y end end function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi1 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi2 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi3 = DynamicPPL.init!!( - Xoshiro(469), model, deepcopy(empty_vi), strategy - ) - @test vi1[@varname(x)] == vi2[@varname(x)] - @test vi1[@varname(y)] == vi2[@varname(y)] - @test vi1[@varname(x)] != vi3[@varname(x)] - @test vi1[@varname(y)] != vi3[@varname(y)] - end + empty_vi = VarInfo() + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] end end @@ -578,21 +562,20 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x, y=my_y) params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_nt = getlogp(vi) - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_dict = getlogp(vi) - @test logp_nt == logp_dict - end + empty_vi = VarInfo() + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict end @testset "given only partial parameters" begin @@ -600,56 +583,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x) params_dict = Dict(@varname(x) => my_x) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - @testset 
"with InitFromPrior fallback" begin - _, vi = DynamicPPL.init!!( - Xoshiro(468), - model, - deepcopy(empty_vi), - InitFromParams(params_nt, InitFromPrior()), - ) - @test vi[@varname(x)] == my_x - nt_y = vi[@varname(y)] - @test nt_y isa AbstractVector{<:Real} - @test length(nt_y) == 2 - _, vi = DynamicPPL.init!!( - Xoshiro(469), - model, - deepcopy(empty_vi), - InitFromParams(params_dict, InitFromPrior()), - ) - @test vi[@varname(x)] == my_x - dict_y = vi[@varname(y)] - @test dict_y isa AbstractVector{<:Real} - @test length(dict_y) == 2 - # the values should be different since we used different seeds - @test dict_y != nt_y - end + empty_vi = VarInfo() + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end - @testset "with no fallback" begin - # These just don't have an entry for `y`. - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) - ) - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) - ) - # We also explicitly test the case where `y = missing`. - params_nt_missing = (; x=my_x, y=missing) - params_dict_missing = Dict( - @varname(x) => my_x, @varname(y) => missing - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_nt_missing, nothing), - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_dict_missing, nothing), - ) - end + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. 
+ params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict(@varname(x) => my_x, @varname(y) => missing) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) end end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..343282480 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -67,60 +67,6 @@ model = ModelOuterWorking2() @test check_model(model, VarInfo(model); error_on_failure=true) end - - @testset "subsumes (x then x[1])" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x ~ MvNormal(zeros(2), I) - x[1] ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x[1] then x)" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x[1] ~ Normal() - x ~ MvNormal(zeros(2), I) - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x.a then x)" begin - @model function buggy_subsumes_demo_model() - x = (a=nothing,) - x.a ~ Normal() - x ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end end @testset "NaN in data" begin diff --git a/test/lkj.jl b/test/lkj.jl index 5c5603aba..bab3ce185 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -37,7 +37,7 @@ end last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples ] corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) + M = reshape(DynamicPPL.getindex_internal(s, @varname(x)), (2, 2)) pd_from_triangular(M, uplo) end @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol From 9812ad0b950f83c1911ce8f5857c047ae1335f2e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:44:53 +0000 Subject: [PATCH 117/148] Comment back in include of old VI files --- src/DynamicPPL.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 84b8b2e68..d6248a4d0 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,14 +196,14 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -# include("varnamedvector.jl") +include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") # include("varinfo.jl") include("vntvarinfo.jl") -# include("simple_varinfo.jl") +include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") From 6d44954d7420ccf80386e01985fe7b288e678921 Mon 
Sep 17 00:00:00 2001
From: Markus Hauru
Date: Mon, 12 Jan 2026 14:19:56 +0000
Subject: [PATCH 118/148] Remove JET extension and experimental.jl

---
 Project.toml                 |   3 -
 docs/Project.toml            |   2 -
 docs/src/api.md              |   9 ---
 ext/DynamicPPLJETExt.jl      |  56 -----------------
 src/DynamicPPL.jl            |  22 -------
 src/experimental.jl          |  98 ------------------------------
 test/Project.toml            |   2 -
 test/ext/DynamicPPLJETExt.jl | 113 -----------------------------------
 test/runtests.jl             |   3 -
 9 files changed, 308 deletions(-)
 delete mode 100644 ext/DynamicPPLJETExt.jl
 delete mode 100644 src/experimental.jl
 delete mode 100644 test/ext/DynamicPPLJETExt.jl

diff --git a/Project.toml b/Project.toml
index 1b899c906..a1a95c822 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
@@ -40,7 +39,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
 DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
 DynamicPPLForwardDiffExt = ["ForwardDiff"]
-DynamicPPLJETExt = ["JET"]
 DynamicPPLMCMCChainsExt = ["MCMCChains"]
 DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
 DynamicPPLMooncakeExt = ["Mooncake"]
@@ -62,7 +60,6 @@ DocStringExtensions = "0.9"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10.12, 1"
 InteractiveUtils = "1"
-JET = "0.9, 0.10, 0.11"
 KernelAbstractions = "0.9.33"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
diff --git a/docs/Project.toml b/docs/Project.toml
index d5fa9a637..5bdb0a2db 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -8,7 +8,6 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
@@ -24,7 +23,6 @@ DocumenterMermaid = "0.1, 0.2"
 DynamicPPL = "0.40"
 FillArrays = "0.13, 1"
 ForwardDiff = "0.10, 1"
-JET = "0.9, 0.10, 0.11"
 LogDensityProblems = "2"
 MCMCChains = "5, 6, 7"
 MarginalLogDensities = "0.4"
diff --git a/docs/src/api.md b/docs/src/api.md
index 084639c07..bfc5dcc8d 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -553,15 +553,6 @@ init
 get_param_eltype
 ```
 
-### Choosing a suitable VarInfo
-
-There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:
-
-```@docs
-DynamicPPL.Experimental.determine_suitable_varinfo
-DynamicPPL.Experimental.is_suitable_varinfo
-```
-
 ### Converting VarInfos to/from chains
 
 It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. 
diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl deleted file mode 100644 index cb35c5ffb..000000000 --- a/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,56 +0,0 @@ -module DynamicPPLJETExt - -using DynamicPPL: DynamicPPL -using JET: JET - -function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_dppl::Bool=true -) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) - # If specified, we only check errors originating somewhere in the DynamicPPL.jl. - # This way we don't just fall back to untyped if the user's code is the issue. - result = if only_dppl - JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) - else - JET.report_call(f, argtypes) - end - return length(JET.get_reports(result)) == 0, result -end - -function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model; only_dppl::Bool=true -) - # Generate a typed varinfo to test model type stability with - varinfo = DynamicPPL.typed_varinfo(model) - - # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !eval_issuccess - @debug "Evaluation with typed varinfo failed with the following issues:" - @debug eval_result - end - - # Check type stability of initialisation (i.e. InitContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !init_issuccess - @debug "Initialisation with typed varinfo failed with the following issues:" - @debug init_result - end - - # If neither of them failed, we can return the typed varinfo as it's type stable. - return if (eval_issuccess && init_issuccess) - varinfo - else - # Warn the user that we can't use the type stable one. - @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model) - end -end - -end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6248a4d0..b84c076be 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -211,7 +211,6 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") -include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -223,27 +222,6 @@ include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) function __init__() - # Better error message if users forget to load JET.jl - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ - requires_jet = - exc.f === DynamicPPL.Experimental._determine_varinfo_jet && - length(argtypes) >= 2 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext - requires_jet |= - exc.f === DynamicPPL.Experimental.is_suitable_varinfo && - length(argtypes) >= 3 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext && - argtypes[3] <: AbstractVarInfo - if requires_jet - print( - io, - "\n$(exc.f) requires JET.jl to be loaded. 
Please run `using JET` before calling $(exc.f).", - ) - end - end - # Same for MarginalLogDensities.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_mld = diff --git a/src/experimental.jl b/src/experimental.jl deleted file mode 100644 index 8c82dca68..000000000 --- a/src/experimental.jl +++ /dev/null @@ -1,98 +0,0 @@ -module Experimental - -using DynamicPPL: DynamicPPL - -# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. -""" - is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) - -Check if the `model` supports evaluation using the provided `varinfo`. - -!!! warning - Loading JET.jl is required before calling this function. - -# Arguments -- `model`: The model to verify the support for. -- `varinfo`: The varinfo to verify the support for. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. - -# Returns -- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. -- `report`: The result of `report_call` from JET.jl. -""" -function is_suitable_varinfo end - -# Internal hook for JET.jl to overload. -function _determine_varinfo_jet end - -""" - determine_suitable_varinfo(model; only_dppl::Bool=true) - -Return a suitable varinfo for the given `model`. - -See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). - -!!! warning - For full functionality, this requires JET.jl to be loaded. - If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. - -# Arguments -- `model`: The model for which to determine the varinfo. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports within DynamicPPL.jl. - -# Examples - -```jldoctest -julia> using DynamicPPL.Experimental: determine_suitable_varinfo - -julia> using JET: JET # needs to be loaded for full functionality - -julia> @model function model_with_random_support() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end -model_with_random_support (generic function with 2 methods) - -julia> model = model_with_random_support(); - -julia> # Typed varinfo cannot handle this random support model properly - # as using a single execution of the model will not see all random variables. - # Hence, this this model requires untyped varinfo. - vi = determine_suitable_varinfo(model); -┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. -└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 - -julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) -true - -julia> # In contrast, a simple model with no random support can be handled by typed varinfo. - @model model_with_static_support() = x ~ Normal() -model_with_static_support (generic function with 2 methods) - -julia> vi = determine_suitable_varinfo(model_with_static_support()); - -julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) -true -``` -""" -function determine_suitable_varinfo(model::DynamicPPL.Model; only_dppl::Bool=true) - # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. - return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model; only_dppl) - else - # Warn the user. - @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." 
- # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). - DynamicPPL.typed_varinfo(model, context) - end -end - -end diff --git a/test/Project.toml b/test/Project.toml index 927954ba4..9c146eb97 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" @@ -46,7 +45,6 @@ Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" ForwardDiff = "0.10.12, 1" -JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" MCMCChains = "7.2.1" MacroTools = "0.5.6" diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl deleted file mode 100644 index e46c25113..000000000 --- a/test/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,113 +0,0 @@ -@testset "DynamicPPLJETExt.jl" begin - @testset "determine_suitable_varinfo" begin - @model function demo1() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end - model = demo1() - @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa - DynamicPPL.UntypedVarInfo - - @model demo2() = x ~ Normal() - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.NTVarInfo - - @model function demo3() - # Just making sure that nothing strange happens when type inference fails. - x = Vector(undef, 1) - x[1] ~ Bernoulli() - if x[1] - y ~ Normal() - else - z ~ Normal() - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa - DynamicPPL.UntypedVarInfo - - # Evaluation works (and it would even do so in practice), but sampling - # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. - @model function demo4() - x ~ Bernoulli() - if x - y ~ Normal() - else - y ~ Cauchy() # different distibution, but same transformation - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo - - # In this model, the type error occurs in the user code rather than in DynamicPPL. - @model function demo5() - x ~ Normal() - xs = Any[] - push!(xs, x) - # `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the - # correct `zero` method. As a result, this code will run, but JET will raise this is an issue. - return sum(xs) - end - # Should pass if we're only checking the tilde statements. - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.NTVarInfo - # Should fail if we're including errors in the model body. - @test DynamicPPL.Experimental.determine_suitable_varinfo( - demo5(); only_dppl=false - ) isa DynamicPPL.UntypedVarInfo - end - - @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS - if model.f === DynamicPPL.TestUtils.demo_lkjchol - # TODO(mhauru) - # The LKJCholesky model fails with JET. The problem is not with Turing but - # with Distributions, and ultimately this in LinearAlgebra: - # julia> v = @view rand(2,2)[:,1]; - # - # julia> JET.@report_call norm(v) - # ═════ 2 possible errors found ═════ - # blahblah - # The below trivial call to @test is just marking that there's something - # broken here. 
- @test false broken = true - continue - end - # Use debug logging below. - varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.NTVarInfo - @test is_typed - # If the test failed, check what the type stability problem was for - # the typed varinfo. This is mostly useful for debugging from test - # logs. - if !is_typed - @info "Model `$(model.f)` is not type stable with typed varinfo." - typed_vi = DynamicPPL.typed_varinfo(model) - - @info "Evaluating with DefaultContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - - @info "Initialising with InitContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - end - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index e04b664fe..23dda437b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,8 +28,6 @@ using Test using Distributions using LinearAlgebra # Diagonal -using JET: JET - using Combinatorics: combinations using OrderedCollections: OrderedSet @@ -76,7 +74,6 @@ include("test_util.jl") include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") - include("ext/DynamicPPLJETExt.jl") include("ext/DynamicPPLMarginalLogDensitiesExt.jl") end @testset "ad" begin From d5bfa2c1748e0b3d629da1d79de8000d035357a4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:12:04 +0000 Subject: [PATCH 119/148] Reimplement bijector.jl --- src/bijector.jl | 111 +++++++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/src/bijector.jl b/src/bijector.jl index 31fe7cd88..576205641 100644 --- a/src/bijector.jl +++ b/src/bijector.jl @@ -1,60 +1,65 @@ +struct BijectorAccumulator <: AbstractAccumulator + bijectors::Vector{Any} + sizes::Vector{Int} +end -""" - bijector(model::Model[, sym2ranges = Val(false)]) +BijectorAccumulator() = BijectorAccumulator(Bijectors.Bijector[], UnitRange{Int}[]) + +function Base.:(==)(acc1::BijectorAccumulator, acc2::BijectorAccumulator) + return (acc1.bijectors == acc2.bijectors && acc1.sizes == acc2.sizes) +end + +function Base.copy(acc::BijectorAccumulator) + return BijectorAccumulator(copy(acc.bijectors), copy(acc.sizes)) +end + +accumulator_name(::Type{<:BijectorAccumulator}) = :Bijector + +function _zero(acc::BijectorAccumulator) + return BijectorAccumulator(empty(acc.bijectors), empty(acc.sizes)) +end +reset(acc::BijectorAccumulator) = _zero(acc) +split(acc::BijectorAccumulator) = _zero(acc) +function combine(acc1::BijectorAccumulator, acc2::BijectorAccumulator) + return BijectorAccumulator( + vcat(acc1.bijectors, acc2.bijectors), vcat(acc1.sizes, acc2.sizes) + ) +end + +function accumulate_assume!!(acc::BijectorAccumulator, val, logjac, vn, right) + bijector = _compose_no_identity( + to_linked_vec_transform(right), from_vec_transform(right) + ) + push!(acc.bijectors, bijector) + push!(acc.sizes, prod(output_size(to_vec_transform(right), right); init=1)) + return acc +end + 
+accumulate_observe!!(acc::BijectorAccumulator, right, left, vn) = acc -Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` -denoting the dimensionality of the latent variables. """ -function Bijectors.bijector( - model::DynamicPPL.Model, - (::Val{sym2ranges})=Val(false); - varinfo=DynamicPPL.VarInfo(model), -) where {sym2ranges} - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) - - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] - end + bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - bs = map(tuple(dists...)) do d - b = Bijectors.bijector(d) - if d isa Distributions.UnivariateDistribution - b - else - # Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)` - # and produces a vector of length `prod(Bijectors.output(f, in_size))`. - in_size = size(d) - vec_in_length = prod(in_size) - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) - out_size = Bijectors.output_size(b, in_size) - vec_out_length = prod(out_size) - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) - reshape_outer ∘ b ∘ reshape_inner - end - end +Returns a `Stacked <: Bijector` which maps from constrained to unconstrained space. + +The input to the bijector is a vector of values for the whole model, like the input to +`unflatten!!`. These are in constrained space, i.e., respecting variable constraints. +The output is a vector of unconstrained values. - if sym2ranges - return ( - Bijectors.Stacked(bs, ranges), - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), - ) - else - return Bijectors.Stacked(bs, ranges) +`init_strategy` is passed to `DynamicPPL.init!!` to determine what values the model is +evaluated with. This may affect the results if the prior distributions or constraints of +variables are dependent on other variables. +""" +function Bijectors.bijector( + model::DynamicPPL.Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + vi = OnlyAccsVarInfo((BijectorAccumulator(),)) + vi = last(DynamicPPL.init!!(model, vi, init_strategy)) + acc = getacc(vi, Val(:Bijector)) + ranges = foldl(acc.sizes; init=UnitRange{Int}[]) do cumulant, sz + last_index = length(cumulant) > 0 ? 
last(cumulant).stop : 0 + push!(cumulant, (last_index + 1):(last_index + sz)) + return cumulant end + return Bijectors.Stacked(acc.bijectors, ranges) end From eb903e1ac79bf611422cd5cc132081af4ee70acb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:13:06 +0000 Subject: [PATCH 120/148] Move linking code to VarInfo, fix ProductNamedDistribution bijector, etc --- src/contexts/init.jl | 64 ++------------------------- src/model.jl | 4 +- src/test_utils/model_interface.jl | 4 +- src/utils.jl | 19 +++++++- src/vntvarinfo.jl | 73 ++++++++++++++++++++++++++----- test/model.jl | 2 +- 6 files changed, 89 insertions(+), 77 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 45d6356f1..65ea08ec5 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -309,68 +309,10 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) val, transform = init(ctx.rng, vn, dist, ctx.strategy) - x, inv_logjac = with_logabsdet_jacobian(transform, val) - # Determine whether to insert a transformed value into the VarInfo. - # If the VarInfo alrady had a value for this variable, we will - # keep the same linked status as in the original VarInfo. If not, we - # check the rest of the VarInfo to see if other variables are linked. - # is_transformed(vi) returns true if vi is nonempty and all variables in vi - # are linked. - insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) - val_to_insert, logjac = if insert_transformed_value - # Calculate the forward logjac and sum them up. - lt = link_transform(dist) - y, fwd_logjac = with_logabsdet_jacobian(lt, x) - transform = _compose_no_identity(transform, lt) - # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian - # calculation wastes a lot of time going from linked vectorised -> unlinked -> - # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. - # - # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which - # case this branch is never hit (since `in_varinfo` will always be false). It does - # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, - # linked, VarInfo will be very slow. That should never really be used, though. So - # (at least for now) we can leave this branch in for full generality with other - # combinations of init strategies / VarInfo. - # - # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue - # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, - # which is NOT the same as `inverse(link_transform)` (because there is an additional - # vectorisation step). We need `init` and `tilde_assume!!` to share this information - # but it's not clear right now how to do this. In my opinion, there are a couple of - # potential ways forward: - # - # 1. Just remove metadata entirely so that there is never any need to construct - # a linked vectorised value again. This would require us to use VAIMAcc as the only - # way of getting values. I consider this the best option, but it might take a long - # time. - # - # 2. Clean up the behaviour of bijectors so that we can have a complete separation - # between the linking and vectorisation parts of it. That way, `x` can either be - # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of - # which it is, we should only need to apply at most one linking and one - # vectorisation transform. 
Doing so would allow us to remove the first call to - # `with_logabsdet_jacobian`, and instead compose and/or uncompose the - # transformations before calling `with_logabsdet_jacobian` once. - y, -inv_logjac + fwd_logjac - else - x, -inv_logjac - end - # Add the new value to the VarInfo. `push!!` errors if the value already - # exists, hence the need for setindex!!. - vi = if vi isa VNTVarInfo - x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () - vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) - else - push!!(vi, vn, val_to_insert, dist) - end - # Neither of these set the `trans` flag so we have to do it manually if - # necessary. - if insert_transformed_value - vi = set_transformed!!(vi, true, vn) - end + x, init_logjac = with_logabsdet_jacobian(transform, val) + vi, logjac = setindex_with_dist!!(vi, x, dist, vn) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + vi = accumulate_assume!!(vi, x, init_logjac + logjac, vn, dist) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. return x, vi diff --git a/src/model.jl b/src/model.jl index 8bfeaf6a1..91558ecdc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1085,7 +1085,9 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) + vi = VarInfo() + vi = setaccs!!(vi, DynamicPPL.AccumulatorTuple()) + x = last(init!!(rng, model, vi)) return values_as(x, T) end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index e7fb16fbe..9914c05ca 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -89,10 +89,10 @@ function logprior_true_with_logabsdet_jacobian end Return a collection of `VarName` as they are expected to appear in the model. Even though it is recommended to implement this by hand for a particular `Model`, -a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +a default implementation using [`VarInfo`](@ref) is provided. """ function varnames(model::Model) - result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict()))))) + result = collect(keys(last(DynamicPPL.init!!(model, VarInfo())))) # Concretise the element type. return [x for x in result] end diff --git a/src/utils.jl b/src/utils.jl index 0e03c5cdc..ba79f94b4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -406,7 +406,7 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) -struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} +struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} <: Bijectors.Bijector dists::T # The `i`-th input range corresponds to the segment of the input vector # that belongs to the `i`-th distribution. 
@@ -439,13 +439,30 @@
     return expr
 end
 
+@generated function (inv_trf::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform{names}})(
+    x::NamedTuple{names}
+) where {names}
+    exprs = Expr[]
+    for name in names
+        push!(exprs, :(to_vec_transform(inv_trf.orig.dists.$name)(x.$name)))
+    end
+    return :(vcat($(exprs...)))
+end
+
 function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
     return ProductNamedTupleUnvecTransform(dist)
 end
+
 function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x)
     return f(x), zero(LogProbType)
 end
 
+function Bijectors.with_logabsdet_jacobian(
+    inv_f::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform}, x
+)
+    return inv_f(x), zero(LogProbType)
+end
+
 # This function returns the length of the vector that the function from_vec_transform
 # expects. This helps us determine which segment of a concatenated vector belongs to which
 # variable.
diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl
index a0392334c..6ce1a861e 100644
--- a/src/vntvarinfo.jl
+++ b/src/vntvarinfo.jl
@@ -61,19 +61,39 @@ function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName)
     return VNTVarInfo(new_values, vi.accs)
 end
 
-BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val)
-
-# TODO(mhauru) The arguments are in the wrong order, but this is the current convetion.
-function BangBang.push!!(
-    vi::VNTVarInfo, vn::VarName, val, transform=typed_identity, orig_size=size(val)
-)
+# TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However,
+# we need `dist` to determine the linking transformation (or even just the vectorisation
+# transformation, in the case of ProductNamedTupleDistributions), and if we leave the work
+# of doing the transformation to the caller, it'll be done even when e.g. using
+# OnlyAccsVarInfo. Hence having this function. Hopefully it will eventually be removed
+# once VAIMAcc is the only way to get values out of an evaluation.
+function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarName)
+    # Determine whether to insert a transformed value into `vi`.
+    # If the VarInfo already had a value for this variable, we will
+    # keep the same linked status as in the original VarInfo. If not, we
+    # check the rest of the VarInfo to see if other variables are linked.
+    # is_transformed(vi) returns true if vi is nonempty and all variables in vi
+    # are linked.
+    insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi)
     # TODO(mhauru) We should move away from having all values vectorised by default.
     # That messes with our use of unflatten though, so will require some thought.
-    transform = _compose_no_identity(transform, from_vec_transform(val))
-    val = to_vec_transform(val)(val)
-    new_tv = TransformedValue(val, false, transform, orig_size)
-    new_values = setindex!!(vi.values, new_tv, vn)
-    return VNTVarInfo(new_values, vi.accs)
+    transform = if insert_transformed_value
+        from_linked_vec_transform(dist)
+    else
+        from_vec_transform(dist)
+    end
+    transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val)
+    val_size = hasmethod(size, Tuple{typeof(val)}) ?
size(val) : () + tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size) + vi = VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) + return vi, logjac +end + +function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) + transform = from_vec_transform(val) + transformed_val = inverse(transform)(val) + tv = TransformedValue(transformed_val, false, transform, size(val)) + return VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) end Base.keys(vi::VNTVarInfo) = keys(vi.values) @@ -86,6 +106,20 @@ function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) return VNTVarInfo(new_values, vi.accs) end +# VNTVarInfo does not care whether the transformation was Static or Dynamic, it just tracks +# whether one was applied at all. +function set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation, vn::VarName) + return set_transformed!!(vi, true, vn) +end + +set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) + +function set_transformed!!(vi::VNTVarInfo, ::NoTransformation, vn::VarName) + return set_transformed!!(vi, false, vn) +end + +set_transformed!!(vi::VNTVarInfo, ::NoTransformation) = set_transformed!!(vi, false) + function set_transformed!!(vi::VNTVarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv TransformedValue(tv.val, linked, tv.transform, tv.size) @@ -238,6 +272,23 @@ function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} end, vi.values; init=T()) end +# TODO(mhauru) I really dislike this sort of conversion to Symbols, but it's the current +# interface provided by rand(::Model). We should change that to return a VarNamedTuple +# instead, and then this method (and any other values_as methods for NamedTuple) could be +# removed. +function values_as(vi::VNTVarInfo, ::Type{NamedTuple}) + return mapfoldl( + identity, + function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, Symbol(vn)) + end, + vi.values; + init=NamedTuple(), + ) +end + # TODO(mhauru) These two are now redundant, just conforming to the old interface # temporarily. function untyped_varinfo( diff --git a/test/model.jl b/test/model.jl index 05688c224..7c5dc2fcc 100644 --- a/test/model.jl +++ b/test/model.jl @@ -311,7 +311,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) + vi = last(DynamicPPL.init!!(model, VarInfo())) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) From 469a71514626f916c4594c2a9592d80b8752b9c3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:54:33 +0000 Subject: [PATCH 121/148] Mark a test as broken --- src/abstract_varinfo.jl | 7 +++++-- test/chains.jl | 10 +++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ef1d92042..0c15cb9c7 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -839,12 +839,15 @@ function link!!( ) # TODO(mhauru) This assumes that the user has defined the bijector using the same # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. 
This is a bad user - # interface, and it's also dangerous for any AbstractVarInfo types that may not respect - # a particular ordering, such as SimpleVarInfo{Dict}. + # interface. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) # Set parameters and add the logjac term. + # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant + # that getindex(vi, vn) would apply the default link transform of the distribution. With + # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any + # transform. Neither is correct, rather the transform should be the inverse of b. vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) diff --git a/test/chains.jl b/test/chains.jl index 608a9a9cf..d69d2d4ca 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -68,8 +68,16 @@ end @testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - unlinked_vi = VarInfo(m) + if m.f === DynamicPPL.TestUtils.demo_static_transformation + # TODO(mhauru) These tests are broken for demo_static_transformation because + # vi[vn] doesn't know which transform it should apply to the internally stored + # value. This requires a rethink, either of StaticTransformation or of what the + # comparison in this test should be. + @test false broken = true + continue + end @testset "$islinked" for islinked in (false, true) + unlinked_vi = VarInfo(m) vi = if islinked DynamicPPL.link!!(unlinked_vi, m) else From 89a8396a35ab7db7f5ad9d6ba8d28d88b0c5b147 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 19:32:33 +0000 Subject: [PATCH 122/148] Various bugfixes --- src/threadsafe.jl | 15 +++++---------- test/compiler.jl | 10 +++++----- test/contexts.jl | 6 ++++-- test/varinfo.jl | 12 ++++++------ 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f168eb7c1..88200680a 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -67,16 +67,6 @@ end has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) -function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) -end - -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, transform=typed_identity, orig_size=size(r) -) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, transform, orig_size) -end - syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -168,6 +158,11 @@ function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::D return getindex(vi.varinfo, vns, dist) end +function setindex_with_dist!!(vi::ThreadSafeVarInfo, val, dist::Distribution, vn::VarName) + vi_inner, logjac = setindex_with_dist!!(vi.varinfo, val, dist, vn) + return Accessors.@set(vi.varinfo = vi_inner), logjac +end + function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) end diff --git a/test/compiler.jl b/test/compiler.jl index e4a9a2474..8d0105947 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -604,9 +604,9 @@ module Issue537 end # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. 
@model demo() = return __varinfo__ - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) - @test svi == SimpleVarInfo() - @test retval == svi + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) + @test vi == VarInfo() + @test retval == vi # We should not be altering return-values other than at top-level. @model function demo() @@ -615,11 +615,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 24f6445f5..435561267 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -448,8 +448,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # start by generating some rubbish values vi = deepcopy(empty_vi) old_x, old_y = 100000.00, [300000.00, 500000.00] - push!!(vi, @varname(x), old_x, Normal()) - push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + vi, _ = DynamicPPL.setindex_with_dist!!(vi, old_x, Normal(), @varname(x)) + vi, _ = DynamicPPL.setindex_with_dist!!( + vi, old_y, MvNormal(fill(old_x, 2), I), @varname(y) + ) # then overwrite it _, new_vi = DynamicPPL.init!!(model, vi, strategy) new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] diff --git a/test/varinfo.jl b/test/varinfo.jl index 0bea67402..1d01a0cf8 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -17,7 +17,7 @@ end @testset "varinfo.jl" begin @testset "Base" begin # Test Base functions: - # in, keys, haskey, isempty, push!!, empty!!, + # in, keys, haskey, isempty, setindex!!, empty!!, # getindex, setindex!, getproperty, setproperty! 
vi = VarInfo() @@ -30,7 +30,7 @@ end @test isempty(vi) @test !haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r) + vi = setindex!!(vi, r, vn) @test !isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -44,7 +44,7 @@ end vi = empty!!(vi) @test isempty(vi) - vi = push!!(vi, vn, r) + vi = setindex!!(vi, r, vn) @test !isempty(vi) end @@ -223,7 +223,7 @@ end vn_x = @varname x r = rand() - vi = push!!(vi, vn_x, r) + vi = setindex!!(vi, r, vn_x) # is_transformed is unset by default @test !is_transformed(vi, vn_x) @@ -637,9 +637,9 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0) + vi_single = setindex!!(vi_single, 1.0, vn) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6]) + vi_double = setindex!!(vi_double, [0.5, 0.6], vn) @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end From 8cf8ab0dfa88afdf1e0efa9ca60cbc138686121a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 20:05:10 +0000 Subject: [PATCH 123/148] Remove SimpleVarInfo, VarNamedVector, and the old VarInfo type --- benchmarks/benchmarks.jl | 61 +- benchmarks/src/DynamicPPLBenchmarks.jl | 43 +- benchmarks/src/Models.jl | 2 +- docs/src/api.md | 24 +- docs/src/internals/varinfo.md | 295 +--- ext/DynamicPPLChainRulesCoreExt.jl | 2 - src/DynamicPPL.jl | 6 +- src/abstract_varinfo.jl | 101 +- src/contexts/transformation.jl | 2 +- src/logdensityfunction.jl | 50 - src/model.jl | 111 ++ src/simple_varinfo.jl | 647 --------- src/threadsafe.jl | 2 - src/utils.jl | 5 - src/varinfo.jl | 1810 ------------------------ src/varnamedvector.jl | 1674 ---------------------- src/vntvarinfo.jl | 9 + test/model.jl | 10 +- test/runtests.jl | 2 - test/simple_varinfo.jl | 345 ----- test/test_util.jl | 23 - test/varinfo.jl | 36 +- test/varnamedvector.jl | 711 ---------- 23 files changed, 198 insertions(+), 5773 deletions(-) delete mode 100644 src/simple_varinfo.jl delete mode 100644 src/varinfo.jl delete mode 100644 src/varnamedvector.jl delete mode 100644 test/simple_varinfo.jl delete mode 100644 test/varnamedvector.jl diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e8ffa7e0b..5be32fdef 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -9,9 +9,7 @@ using StableRNGs: StableRNG rng = StableRNG(23) -colnames = [ - "Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)" -] +colnames = ["Model", "Dim", "AD Backend", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"] function print_results(results_table; to_json=false) if to_json # Print to the given file as JSON @@ -58,31 +56,26 @@ function run(; to_json=false) end # Specify the combinations to test: - # (Model Name, model instance, VarInfo choice, AD backend, linked) + # (Model Name, model instance, AD backend, linked) chosen_combinations = [ ( "Simple assume observe", Models.simple_assume_observe(randn(rng)), - :typed, :forwarddiff, false, ), - ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), - 
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), - ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), - ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), - ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), - ("Multivariate 10k", multivariate10k, :typed, :mooncake, true), - ("Dynamic", Models.dynamic(), :typed, :mooncake, true), - ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), - ("LDA", lda_instance, :typed, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, false), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :mooncake, true), + ("Smorgasbord", smorgasbord_instance, :enzyme, true), + ("Loop univariate 1k", loop_univariate1k, :mooncake, true), + ("Multivariate 1k", multivariate1k, :mooncake, true), + ("Loop univariate 10k", loop_univariate10k, :mooncake, true), + ("Multivariate 10k", multivariate10k, :mooncake, true), + ("Dynamic", Models.dynamic(), :mooncake, true), + ("Submodel", Models.parent(randn(rng)), :mooncake, true), + ("LDA", lda_instance, :reversediff, true), ] # Time running a model-like function that does not use DynamicPPL, as a reference point. @@ -94,13 +87,13 @@ function run(; to_json=false) @info "Reference evaluation time: $(reference_time) seconds" results_table = Tuple{ - String,Int,String,String,Bool,Union{Float64,Missing},Union{Float64,Missing} + String,Int,String,Bool,Union{Float64,Missing},Union{Float64,Missing} }[] - for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" + for (model_name, model, adbackend, islinked) in chosen_combinations + @info "Running benchmark for $model_name, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try - results = benchmark(model, varinfo_choice, adbackend, islinked) + results = benchmark(model, adbackend, islinked) @info " t(eval) = $(results.primal_time)" @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), @@ -115,7 +108,6 @@ function run(; to_json=false) model_name, model_dimension(model, islinked), string(adbackend), - string(varinfo_choice), islinked, relative_eval_time, relative_ad_eval_time, @@ -131,9 +123,8 @@ struct TestCase model_name::String dim::Integer ad_backend::String - varinfo::String linked::Bool - TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:5])...) + TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:4])...) 
end function combine(head_filename::String, base_filename::String) head_results = try @@ -148,23 +139,22 @@ function combine(head_filename::String, base_filename::String) Dict{String,Any}[] end @info "Loaded $(length(base_results)) results from $base_filename" - # Identify unique combinations of (Model, Dim, AD Backend, VarInfo, Linked) + # Identify unique combinations of (Model, Dim, AD Backend, Linked) head_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in head_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in head_results ) base_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in base_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in base_results ) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.ad_backend)) ) results_table = Tuple{ String, Int, String, - String, Bool, String, String, @@ -179,12 +169,12 @@ function combine(head_filename::String, base_filename::String) sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ - EmptyCells(5), + EmptyCells(4), MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., sublabels..., sublabels..., sublabels...], + [colnames[1:4]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -211,7 +201,6 @@ function combine(head_filename::String, base_filename::String) c.model_name, c.dim, c.ad_backend, - c.varinfo, c.linked, sprint_float(base_eval), sprint_float(head_eval), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 0dc7ece6e..6bb8672c9 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, VarName using DynamicPPL: DynamicPPL using DynamicPPL.TestUtils.AD: run_ad, NoTest using ADTypes: ADTypes @@ -23,7 +23,7 @@ Return the dimension of `model`, accounting for linking, if any. """ function model_dimension(model, islinked) vi = VarInfo() - model(StableRNG(23), vi) + vi = last(DynamicPPL.init!!(StableRNG(23), model, vi)) if islinked vi = DynamicPPL.link(vi, model) end @@ -52,53 +52,24 @@ function to_backend(x::Union{AbstractString,Symbol}) end """ - benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) + benchmark(model, adbackend::Symbol, islinked::Bool) -Benchmark evaluation and gradient calculation for `model` using the selected varinfo type -and AD backend. - -Available varinfo choices: - • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` - • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) +Benchmark evaluation and gradient calculation for `model` using the selected AD backend. The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). `islinked` determines whether to link the VarInfo for evaluation. 
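+
+For illustration, a call along these lines should work (a sketch only; it assumes the
+`Models` module from this package is in scope):
+
+```julia
+model = Models.simple_assume_observe(0.5)
+results = benchmark(model, :forwarddiff, false)
+results.primal_time, results.grad_time
+```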
""" -function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) +function benchmark(model, adbackend::Symbol, islinked::Bool) rng = StableRNG(23) - + vi = VarInfo(rng, model) adbackend = to_backend(adbackend) - - vi = if varinfo_choice == :untyped - DynamicPPL.untyped_varinfo(rng, model) - elseif varinfo_choice == :typed - DynamicPPL.typed_varinfo(rng, model) - elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) - elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) - elseif varinfo_choice == :typed_vector - DynamicPPL.typed_vector_varinfo(rng, model) - elseif varinfo_choice == :untyped_vector - DynamicPPL.untyped_vector_varinfo(rng, model) - else - error("Unknown varinfo choice: $varinfo_choice") - end - - adbackend = to_backend(adbackend) - if islinked vi = DynamicPPL.link(vi, model) end - return run_ad( model, adbackend; varinfo=vi, benchmark=true, test=NoTest(), verbose=false ) end -end # module +end diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl index 2c881aa95..76d4b2e93 100644 --- a/benchmarks/src/Models.jl +++ b/benchmarks/src/Models.jl @@ -2,7 +2,7 @@ Models for benchmarking Turing.jl. Each model returns a NamedTuple of all the random variables in the model that are not -observed (this is used for constructing SimpleVarInfos). +observed. """ module Models diff --git a/docs/src/api.md b/docs/src/api.md index bfc5dcc8d..a506c793e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -333,27 +333,18 @@ Please see the documentation of [AbstractPPL.jl](https://github.com/TuringLang/A ### Data Structures of Variables -DynamicPPL provides different data structures used in for storing samples and accumulation of the log-probabilities, all of which are subtypes of [`AbstractVarInfo`](@ref). +DynamicPPL provides a data structure for storing samples and accumulation of the log-probabilities, called [`VarInfo`](@ref). +The interface that `VarInfo` respects is described by the abstract type [`AbstractVarInfo`](@ref). +Internally DynamicPPL also uses a couple of other subtypes of `AbstractVarInfo`. ```@docs AbstractVarInfo ``` -But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. - -#### `VarInfo` - ```@docs VarInfo ``` -```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo -DynamicPPL.untyped_vector_varinfo -DynamicPPL.typed_vector_varinfo -``` - One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). The [Transformations section below](#Transformations) describes the methods used for this. In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. @@ -367,14 +358,11 @@ set_transformed!! Base.empty! ``` -#### `SimpleVarInfo` - -```@docs -SimpleVarInfo -``` - #### `VarNamedTuple` +`VarInfo` is only a thin wrapper around [`VarNamedTuple`](@ref), which stores arbitrary data keyed by `VarName`s. +For more details on `VarNamedTuple`, see the Internals section of our documentation. 
+
 ```@docs
 DynamicPPL.VarNamedTuples.VarNamedTuple
 DynamicPPL.VarNamedTuples.vnt_size
diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md
index b04913aaf..c57ea1fcf 100644
--- a/docs/src/internals/varinfo.md
+++ b/docs/src/internals/varinfo.md
@@ -8,293 +8,50 @@ VarInfo
 It contains
 
-  - a `logp` field for accumulation of the log-density evaluation, and
-  - a `metadata` field for storing information about the realizations of the different variables.
+  - a `VarNamedTuple` field called `values`,
+  - an `AccumulatorTuple` field called `accs`, to hold accumulators.
 
-Representing `logp` is fairly straight-forward: we'll just use a `Real` or an array of `Real`, depending on the context.
+`values` takes care of storing information related to the values of individual random variables, while `accs` keeps track of information that is accumulated over the course of evaluating a model.
 
-**Representing `metadata` is a bit trickier**. This is supposed to contain all the necessary information for each `VarName` to enable the different executions of the model + extraction of different properties of interest after execution, e.g. the realization / value corresponding to a variable `@varname(x)`.
+Variables are recognised by their `VarName`.
+We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information.
+For instance, a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`.
+`VarName`s also allow things such as setting values for `x[1]` and `x[2]` and getting a value for `x` as a whole.
 
-!!! note
-
-    We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information, e.g. a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`.
+To ensure that `VarInfo` is simple and intuitive to work with, we want it to replicate the following functionality of `Dict`:
 
-To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`, and hence the underlying `metadata`, to replicate the following functionality of `Dict`:
-
-  - `keys(::Dict)`: return all the `VarName`s present in `metadata`.
-  - `haskey(::Dict)`: check if a particular `VarName` is present in `metadata`.
-  - `getindex(::Dict, ::VarName)`: return the realization corresponding to a particular `VarName`.
-  - `setindex!(::Dict, val, ::VarName)`: set the realization corresponding to a particular `VarName`.
-  - `push!(::Dict, ::Pair)`: add a new key-value pair to the container.
-  - `delete!(::Dict, ::VarName)`: delete the realization corresponding to a particular `VarName`.
-  - `empty!(::Dict)`: delete all realizations in `metadata`.
-  - `merge(::Dict, ::Dict)`: merge two `metadata` structures according to similar rules as `Dict`.
+  - `keys(::VarInfo)`: return all the `VarName`s present.
+  - `haskey(::VarInfo, ::VarName)`: check if a particular `VarName` is present.
+  - `getindex(::VarInfo, ::VarName)`: return the realization corresponding to a particular `VarName`.
+  - `setindex!!(::VarInfo, val, ::VarName)`: set the realization corresponding to a particular `VarName`.
+  - `delete!!(::VarInfo, ::VarName)`: delete the realization corresponding to a particular `VarName`.
+  - `empty!!(::VarInfo)`: delete all data.
+  - `merge(::VarInfo, ::VarInfo)`: merge two containers according to rules similar to `Dict`'s.
+
+Note that we only define the BangBang methods such as `setindex!!`, rather than the mutating ones like `setindex!`.
+This is due to the design of `VarNamedTuple`, which is explained on its own page in these docs.
 
-*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. One can access a vectorised version of a variable's value with the following vector-like functions:
+*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`.
+One can access a vectorised version of a variable's value with the following vector-like functions:
 
  - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable.
  - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables.
  - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values
-  - `setindex_internal!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
-  - `setindex_internal!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values
+  - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
+  - `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values.
  - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`.
 
 The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised.
 
-Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space. `getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able sample in unconstrained space. One can also manually set a transformation by giving `setindex_internal!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable.
+Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space.
+`getindex_internal` and `setindex_internal!!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able to sample in unconstrained space.
+One can also manually set a transformation by giving `setindex_internal!!` a fourth, optional argument: a function that maps the internally stored value to the actual value of the variable.
 
-Finally, we want want the underlying representation used in `metadata` to have a few performance-related properties:
+Finally, we want the underlying storage to have a few performance-related properties:
 
  1. Type-stable when possible, but functional when not.
 2. Efficient storage and iteration when possible, but functional when not.
 
 The "but functional when not" is important as we want to support arbitrary models, which means that we can't always have these performance properties.
 
-In the following sections, we'll outline how we achieve this in [`VarInfo`](@ref).
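+
+As a small sketch of the flattened access and linking described above (a hypothetical
+model; return values are illustrative):
+
+```julia
+using DynamicPPL, Distributions
+
+@model demo_constrained() = s ~ InverseGamma(2, 3)
+
+model = demo_constrained()
+vi = VarInfo(model)
+
+# Flat vector view of all values, as stored internally.
+xs = DynamicPPL.getindex_internal(vi, :)
+DynamicPPL.length_internal(vi) == length(xs)   # true
+
+# Link to unconstrained space; for a positive variable this is a log transform,
+# so the internal (flattened) representation may now be negative.
+vi_linked = DynamicPPL.link!!(vi, model)
+ys = DynamicPPL.getindex_internal(vi_linked, :)
+```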
- -## Type-stability - -Ensuring type-stability is somewhat non-trivial to address since we want this to be the case even when models mix continuous (typically `Float64`) and discrete (typically `Int`) variables. - -Suppose we have an implementation of `metadata` which implements the functionality outlined in the previous section. The way we approach this in `VarInfo` is to use a `NamedTuple` with a separate `metadata` *for each distinct `Symbol` used*. For example, if we have a model of the form - -```@example varinfo-design -using DynamicPPL, Distributions, FillArrays - -@model function demo() - x ~ product_distribution(Fill(Bernoulli(0.5), 2)) - y ~ Normal(0, 1) - return nothing -end -``` - -then we construct a type-stable representation by using a `NamedTuple{(:x, :y), Tuple{Vx, Vy}}` where - - - `Vx` is a container with `eltype` `Bool`, and - - `Vy` is a container with `eltype` `Float64`. - -Since `VarName` contains the `Symbol` used in its type, something like `getindex(varinfo, @varname(x))` can be resolved to `getindex(varinfo.metadata.x, @varname(x))` at compile-time. - -For example, with the model above we have - -```@example varinfo-design -# Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) -typeof(varinfo_untyped.metadata) -``` - -```@example varinfo-design -# Type-stable `VarInfo` -varinfo_typed = DynamicPPL.typed_varinfo(demo()) -typeof(varinfo_typed.metadata) -``` - -They both work as expected but one results in concrete typing and the other does not: - -```@example varinfo-design -varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)] -``` - -```@example varinfo-design -varinfo_typed[@varname(x)], varinfo_typed[@varname(y)] -``` - -Notice that the untyped `VarInfo` uses `Vector{Real}` to store the boolean entries while the typed uses `Vector{Bool}`. This is because the untyped version needs the underlying container to be able to handle both the `Bool` for `x` and the `Float64` for `y`, while the typed version can use a `Vector{Bool}` for `x` and a `Vector{Float64}` for `y` due to its usage of `NamedTuple`. - -!!! warning - - Of course, this `NamedTuple` approach is *not* necessarily going to help us in scenarios where the `Symbol` does not correspond to a unique type, e.g. - - ```julia - x[1] ~ Bernoulli(0.5) - x[2] ~ Normal(0, 1) - ``` - - In this case we'll end up with a `NamedTuple((:x,), Tuple{Vx})` where `Vx` is a container with `eltype` `Union{Bool, Float64}` or something worse. This is *not* type-stable but will still be functional. - - In practice, we rarely observe such mixing of types, therefore in DynamicPPL, and more widely in Turing.jl, we use a `NamedTuple` approach for type-stability with great success. - -!!! warning - - Another downside with such a `NamedTuple` approach is that if we have a model with lots of tilde-statements, e.g. `a ~ Normal()`, `b ~ Normal()`, ..., `z ~ Normal()` will result in a `NamedTuple` with 27 entries, potentially leading to long compilation times. - - For these scenarios it can be useful to fall back to "untyped" representations. - -Hence we obtain a "type-stable when possible"-representation by wrapping it in a `NamedTuple` and partially resolving the `getindex`, `setindex!`, etc. methods at compile-time. When type-stability is *not* desired, we can simply use a single `metadata` for all `VarName`s instead of a `NamedTuple` wrapping a collection of `metadata`s. - -## Efficient storage and iteration - -Efficient storage and iteration we achieve through implementation of the `metadata`. 
In particular, we do so with [`DynamicPPL.VarNamedVector`](@ref): - -```@docs -DynamicPPL.VarNamedVector -``` - -In a [`DynamicPPL.VarNamedVector{<:VarName,T}`](@ref), we achieve the desiderata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. - -This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: - - - `varnames::Vector{<:VarName}`: the `VarName`s in the order they appear in the `Vector{T}`. - - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. - - `transforms::Vector`: the transforms associated with each `VarName`. - -Mutating functions, e.g. `setindex_internal!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: - - 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. - - 2. If `vn` is already present in `vnv`: - - 1. If `val` has the *same length* as the existing value for `vn`: replace existing value. - 2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field. - 3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. - -This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. - -For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example: - -```julia -# Construct a `VarInfo` with types inferred from `model`. -varinfo = VarInfo(model) - -# Repeatedly sample from `model`. -for _ in 1:num_samples - rand!(rng, model, varinfo) - - # Do something with `varinfo`. - # ... -end -``` - -There are typically a few scenarios where we encounter changing representation sizes of a random variable `x`: - - 1. We're working with a transformed version `x` which is represented in a lower-dimensional space, e.g. transforming a `x ~ LKJ(2, 1)` to unconstrained `y = f(x)` takes us from 2-by-2 `Matrix{Float64}` to a 1-length `Vector{Float64}`. - 2. `x` has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of `x` can vary widly between every realization of the `Model`. - -In scenario (1), we're usually *shrinking* the representation of `x`, and so we end up not making any allocations for the underlying `Vector{T}` but instead just marking the redundant part as "inactive". - -In scenario (2), we end up increasing the allocated memory for the randomly sized `x`, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand. 
- -To help with this, we have the following functions: - -```@docs -DynamicPPL.has_inactive -DynamicPPL.num_inactive -DynamicPPL.num_allocated -DynamicPPL.is_contiguous -DynamicPPL.contiguify! -``` - -For example, one might encounter the following scenario: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - println( - "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - if DynamicPPL.num_allocated(vnv) > 10 - DynamicPPL.contiguify!(vnv) - end - println( - "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. - -!!! note - - Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. 
- -Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: - -```@example varinfo-design -# Type-unstable -varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) -varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] -``` - -```@example varinfo-design -# Type-stable -varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) -varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] -``` - -If we now try to `delete!` `@varname(x)` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -```@example varinfo-design -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -# `delete!` -DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x)) -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -Or insert a differently-sized value for `@varname(x)` - -```@example varinfo-design -DynamicPPL.insert!(varinfo_untyped_vnv.metadata, fill(true, 1), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` - -```@example varinfo-design -DynamicPPL.update!(varinfo_untyped_vnv.metadata, fill(true, 4), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` - -### Performance summary - -In the end, we have the following "rough" performance characteristics for `VarNamedVector`: - -| Method | Is blazingly fast? | -|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| -| `getindex` | ${\color{green} \checkmark}$ | -| `setindex!` on a new `VarName` | ${\color{green} \checkmark}$ | -| `delete!` | ${\color{red} \times}$ | -| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | -| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | - -## Other methods - -```@docs -DynamicPPL.replace_raw_storage(::DynamicPPL.VarNamedVector, vals::AbstractVector) -``` - -```@docs; canonical=false -DynamicPPL.values_as(::DynamicPPL.VarNamedVector) -``` +To understand how these are achieved, we refer the reader to the documentation on `VarNamedTuple`, which underpins `VarInfo`. diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 12b816c60..37c9444b3 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -16,6 +16,4 @@ ChainRulesCore.@non_differentiable BangBang.push!!( # No need + causes issues for some AD backends, e.g. Zygote. 
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) -ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges) - end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b84c076be..b5a77be03 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,7 +46,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - SimpleVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -178,7 +177,7 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). +See also: [`VarInfo`](@ref) """ abstract type AbstractVarInfo <: AbstractModelTrace end @@ -196,14 +195,11 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -# include("varinfo.jl") include("vntvarinfo.jl") -include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0c15cb9c7..1c5159626 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -502,64 +502,12 @@ If no `Type` is provided, return values as stored in `varinfo`. # Examples -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`VarInfo` with `NamedTuple` of `Metadata`: - ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; -julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} -true - julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) @@ -573,32 +521,6 @@ julia> values_as(vi, Vector) 1.0 2.0 ``` - -`VarInfo` with `Metadata`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. 
- values_as(vi) isa Union{DynamicPPL.Metadata, Vector} -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 - -julia> values_as(vi, Vector) -2-element Vector{Real}: - 1.0 - 2.0 -``` """ function values_as end @@ -625,13 +547,6 @@ function Base.eltype(vi::AbstractVarInfo) return eltype(T) end -""" - has_varnamedvector(varinfo::VarInfo) - -Returns `true` if `varinfo` uses `VarNamedVector` as metadata. -""" -has_varnamedvector(vi::AbstractVarInfo) = false - # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert # the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which # might result in a `Vector{Any}`. @@ -828,8 +743,6 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, t) @@ -897,8 +810,6 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, NoTransformation()) @@ -983,12 +894,12 @@ julia> # Change the `default_transformation` for our model to be a julia> model = demo(); -julia> vi = SimpleVarInfo(x=1.0) -SimpleVarInfo((x = 1.0,), 0.0) +julia> vi = setindex!!(VarInfo(), 1.0, @varname(x)); + +julia> vi[@varname(x)] +1.0 -julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` - vi_linked = link!!(vi, model) -Transformed SimpleVarInfo((x = 1.0,), 0.0) +julia> vi_linked = link!!(vi, model); julia> # Now performs a single `invlink!!` before model evaluation. logjoint(model, vi_linked) diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index c2eee2863..0914d7a79 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -7,7 +7,7 @@ constrained space if `isinverse` or unconstrained if `!isinverse`. Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the `DynamicTransformationContext` methods with more efficient implementations. `DynamicTransformationContext` is a fallback for when we need to evaluate the model to know -how to do the transformation, used by e.g. `SimpleVarInfo`. +how to do the transformation. 
""" struct DynamicTransformationContext{isinverse} <: AbstractContext end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 44fdad5a8..4f8ac4933 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -13,8 +13,6 @@ using DynamicPPL: OnlyAccsVarInfo, RangeAndLinked, VectorWithRanges, - # Metadata, - VarNamedVector, default_accumulators, float_type_with_fallback, getlogjoint, @@ -296,11 +294,6 @@ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtyp # Helper functions to extract ranges and link status # ###################################################### -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. """ get_ranges_and_linked(varinfo::VarInfo) @@ -329,46 +322,3 @@ function get_ranges_and_linked(vi::VNTVarInfo) ) return vnt end - -# function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} -# all_ranges = VarNamedTuple() -# offset = 1 -# for sym in syms -# md = varinfo.metadata[sym] -# this_md_others, offset = get_ranges_and_linked_metadata(md, offset) -# all_ranges = merge(all_ranges, this_md_others) -# end -# return all_ranges -# end -# function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) -# all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) -# return all_ranges -# end -# function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) -# all_ranges = VarNamedTuple() -# offset = start_offset -# for (vn, idx) in md.idcs -# is_linked = md.is_transformed[idx] -# range = md.ranges[idx] .+ (start_offset - 1) -# orig_size = varnamesize(vn) -# all_ranges = BangBang.setindex!!( -# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn -# ) -# offset += length(range) -# end -# return all_ranges, offset -# end -# function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) -# all_ranges = VarNamedTuple() -# offset = start_offset -# for (vn, idx) in vnv.varname_to_index -# is_linked = vnv.is_unconstrained[idx] -# range = vnv.ranges[idx] .+ (start_offset - 1) -# orig_size = varnamesize(vn) -# all_ranges = BangBang.setindex!!( -# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn -# ) -# offset += length(range) -# end -# return all_ranges, offset -# end diff --git a/src/model.jl b/src/model.jl index 91558ecdc..cd36ee44b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1151,6 +1151,117 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getloglikelihood(last(evaluate!!(model, varinfo))) end +""" + logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log joint probability of variables `values` for the probabilistic `model`. + +See [`logprior`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logjoint(demo([1.0]), (m = 100.0, )) +-9902.33787706641 + +julia> # Using a `OrderedDict`. + logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-9902.33787706641 + +julia> # Truth. 
+ logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) +-9902.33787706641 +``` +""" +function logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getlogjoint(vi) +end + +""" + logprior(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log prior probability of variables `values` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logprior(demo([1.0]), (m = 100.0, )) +-5000.918938533205 + +julia> # Using a `OrderedDict`. + logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-5000.918938533205 + +julia> # Truth. + logpdf(Normal(), 100.0) +-5000.918938533205 +``` +""" +function logprior(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(),)) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getlogprior(vi) +end + +""" + loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log likelihood of variables `values` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`logprior`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + loglikelihood(demo([1.0]), (m = 100.0, )) +-4901.418938533205 + +julia> # Using a `OrderedDict`. + loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-4901.418938533205 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +function Distributions.loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogLikelihoodAccumulator(),)) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getloglikelihood(vi) +end + # Implemented & documented in DynamicPPLMCMCChainsExt function predict end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 4add65d6d..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,647 +0,0 @@ -""" - $(TYPEDEF) - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. - -# Fields -$(FIELDS) - -# Notes -The major differences between this and `NTVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either - a) no indexing is used in tilde-statements, or - b) the values have been specified with the correct shapes. 
- -# Examples -## General usage -```jldoctest simplevarinfo-general; setup=:(using Distributions) -julia> using StableRNGs - -julia> @model function demo() - m ~ Normal() - x = Vector{Float64}(undef, 2) - for i in eachindex(x) - x[i] ~ Normal() - end - return x - end -demo (generic function with 2 methods) - -julia> m = demo(); - -julia> rng = StableRNG(42); - -julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variables which are using "containers", e.g. `Array`. - # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); - -julia> # (✓) Vroom, vroom! FAST!!! - vi[@varname(x[1])] -0.4471218424633827 - -julia> # We can also access arbitrary varnames pointing to `x`, e.g. - vi[@varname(x)] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> vi[@varname(x[1:2])] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); -ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` -[...] - -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); - -julia> # (✓) Sort of fast, but only possible at runtime. - vi[@varname(x[1])] --1.019202452456547 - -julia> # In addtion, we can only access varnames as they appear in the model! - vi[@varname(x)] -ERROR: x was not found in the dictionary provided -[...] - -julia> vi[@varname(x[1:2])] -ERROR: x[1:2] was not found in the dictionary provided -[...] -``` - -_Technically_, it's possible to use any implementation of `AbstractDict` in place of -`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening -of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is -the preferred implementation of `AbstractDict` to use here. - -You can also sample in _transformed_ space: - -```jldoctest simplevarinfo-general -julia> @model demo_constrained() = x ~ Exponential() -demo_constrained (generic function with 2 methods) - -julia> m = demo_constrained(); - -julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); - -julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ -1.8632965762164932 - -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ --0.21080155351918753 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true - -julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.6225185067787314 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! 
-true -``` - -Evaluation in transformed space of course also works: - -```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) Positive probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --1.3678794411714423 - -julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) No probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --Inf -``` - -## Indexing -Using `NamedTuple` as underlying storage. - -```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); - -julia> svi_nt[@varname(m)] -(a = [1.0],) - -julia> svi_nt[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_nt[@varname(m.a[1])] -1.0 - -julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` -[...] -``` - -Using `OrderedDict` as underlying storage. -```jldoctest -julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); - -julia> svi_dict[@varname(m)] -(a = [1.0],) - -julia> svi_dict[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_dict[@varname(m.a[1])] -1.0 - -julia> svi_dict[@varname(m.a[2])] -ERROR: m.a[2] was not found in the dictionary provided -[...] - -julia> svi_dict[@varname(m.b)] -ERROR: m.b was not found in the dictionary provided -[...] -``` -""" -struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: - AbstractVarInfo - "underlying representation of the realization represented" - values::NT - "tuple of accumulators for things like log prior and log likelihood" - accs::Accs - "represents whether it assumes variables to be transformed" - transformation::C -end - -function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) - return vi1.values == vi2.values && - vi1.accs == vi2.accs && - vi1.transformation == vi2.transformation -end - -transformation(vi::SimpleVarInfo) = vi.transformation - -function SimpleVarInfo(values, accs) - return SimpleVarInfo(values, accs, NoTransformation()) -end -function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, default_accumulators(T)) -end -function SimpleVarInfo(values) - return SimpleVarInfo{LogProbType}(values) -end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) - return if isempty(values) - # Can't infer from values, so we just use default. - SimpleVarInfo{LogProbType}(values) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) - end -end - -# Using `kwargs` to specify the values. -function SimpleVarInfo{T}(; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs)) -end -function SimpleVarInfo(; kwargs...) - return SimpleVarInfo(NamedTuple(kwargs)) -end - -# Constructor from `Model`. 
-function SimpleVarInfo{T}( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) where {T<:Real} - return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) -end -function SimpleVarInfo{T}( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) -end -# Constructors without type param -function SimpleVarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return SimpleVarInfo{LogProbType}(rng, model, init_strategy) -end -function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) -end - -# Constructor from `VarInfo`. -# function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} -# values = values_as(vi, D) -# return SimpleVarInfo(values, copy(getaccs(vi))) -# end -# function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} -# values = values_as(vi, D) -# accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) -# return SimpleVarInfo(values, accs) -# end - -function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(init!!(model, varinfo)) -end - -function typed_simple_varinfo(model::Model) - varinfo = SimpleVarInfo{Float64}() - return last(init!!(model, varinfo)) -end - -function unflatten(svi::SimpleVarInfo, x::AbstractVector) - vals = unflatten(svi.values, x) - return SimpleVarInfo(vals, svi.accs, svi.transformation) -end - -function BangBang.empty!!(vi::SimpleVarInfo) - return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) -end -Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) - -getaccs(vi::SimpleVarInfo) = vi.accs -setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs - -""" - keys(vi::SimpleVarInfo) - -Return an iterator of keys present in `vi`. -""" -Base.keys(vi::SimpleVarInfo) = keys(vi.values) -Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) - -function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transformation isa NoTransformation) - print(io, "Transformed ") - end - - return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") -end - -function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn)) -end -function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) - -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(getindex, vi), vns) -end -# HACK: Needed to disambiguate. 
-Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) - -Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) - -getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `AbstractDict` -function getindex_internal( - vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName -) - return getvalue(vi.values, vn) -end - -Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) - -function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) - # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Accessors.@set vi.values = set!!(vi.values, vn, val) -end - -# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with -# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. -function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) - for (vn, val) in zip(vns, vals) - vi = BangBang.setindex!!(vi, val, vn) - end - return vi -end - -function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) - # For dictlike objects, we treat the entire `vn` as a _key_ to set. - dict = values_as(vi) - # Attempt to split into `parent` and `child` optic. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(dict, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - dict_new = if !issuccess - # Split doesn't exist ⟹ we're working with a new key. - BangBang.setindex!!(dict, val, vn) - else - # Split exists ⟹ trying to set an existing key. - vn_key = VarName{getsym(vn)}(keyoptic) - BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) - end - return Accessors.@set vi.values = dict_new -end - -# `NamedTuple` -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) -end -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = set!!(vi.values, vn, value) -end - -# `AbstractDict` -function BangBang.push!!( - vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution -) - vi.values[vn] = value - return vi -end - -function BangBang.push!!( - vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution -) - # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For - # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. - # Hence we need to call update!! here, which has the same semantics as push!! does for - # SimpleVarInfo. - return Accessors.@set vi.values = setindex!!(vi.values, value, vn) -end - -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - -# Necessary for `matchingvalue` to work properly. 
-Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V - -# `subset` -function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return SimpleVarInfo( - _subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation - ) -end - -function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} - vns_present = collect(keys(x)) - vns_found = filter( - vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present - ) - C = ConstructionBase.constructorof(typeof(x)) - if isempty(vns_found) - return C() - else - return C(vn => x[vn] for vn in vns_found) - end -end - -function _subset(x::NamedTuple, vns) - # NOTE: Here we can only handle `vns` that contain `identity` as optic. - if any(Base.Fix1(!==, identity) ∘ getoptic, vns) - throw( - ArgumentError( - "Cannot subset `NamedTuple` with non-`identity` `VarName`. " * - "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", - ), - ) - end - - syms = map(getsym, vns) - x_syms = filter(Base.Fix2(in, syms), keys(x)) - return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms))) -end - -_subset(x::VarNamedVector, vns) = subset(x, vns) - -# `merge` -function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) - values = merge(varinfo_left.values, varinfo_right.values) - accs = map(copy, getaccs(varinfo_right)) - transformation = merge_transformations( - varinfo_left.transformation, varinfo_right.transformation - ) - return SimpleVarInfo(values, accs, transformation) -end - -function set_transformed!!(vi::SimpleVarInfo, trans) - return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation()) -end -function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation) - return Accessors.@set vi.transformation = transformation -end -function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans) -end -function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) - # We keep this method around just to obey the AbstractVarInfo interface. - # However, note that this would only be a valid operation if it would be a - # no-op, which we check here. 
- if trans != is_transformed(vi) - error( - "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", - ) - end - return vi -end - -is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi) -function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) - return is_transformed(vi.varinfo, vn) -end -is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo) - -values_as(vi::SimpleVarInfo) = vi.values -values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo, ::Type{Vector}) - isempty(vi) && return Any[] - return mapreduce(tovec, vcat, values(vi.values)) -end -function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) -end -function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) - return NamedTuple((Symbol(k), v) for (k, v) in vi.values) -end -function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} - return values_as(vi.values, T) -end - -""" - logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log joint probability of variables `θ` for the probabilistic `model`. - -See [`logprior`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logjoint(demo([1.0]), (m = 100.0, )) --9902.33787706641 - -julia> # Using a `OrderedDict`. - logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --9902.33787706641 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) --9902.33787706641 -``` -""" -logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logjoint(model, SimpleVarInfo(θ)) - -""" - logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log prior probability of variables `θ` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logprior(demo([1.0]), (m = 100.0, )) --5000.918938533205 - -julia> # Using a `OrderedDict`. - logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --5000.918938533205 - -julia> # Truth. - logpdf(Normal(), 100.0) --5000.918938533205 -``` -""" -logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logprior(model, SimpleVarInfo(θ)) - -""" - loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log likelihood of variables `θ` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`logprior`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - loglikelihood(demo([1.0]), (m = 100.0, )) --4901.418938533205 - -julia> # Using a `OrderedDict`. - loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --4901.418938533205 - -julia> # Truth. 
- logpdf(Normal(100.0, 1.0), 1.0) --4901.418938533205 -``` -""" -Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = - loglikelihood(model, SimpleVarInfo(θ)) - -# Allow usage of `NamedBijector` too. -function link!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = inverse(t.bijector) - x = vi.values - y, logjac = with_logabsdet_jacobian(b, x) - vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, logjac) - end - return set_transformed!!(vi_new, t) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = t.bijector - y = vi.values - x, inv_logjac = with_logabsdet_jacobian(b, y) - vi_new = Accessors.@set(vi.values = x) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, inv_logjac) - end - return set_transformed!!(vi_new, NoTransformation()) -end - -# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. -from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity -# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`? -from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) - return invlink_transform(dist) -end - -has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 88200680a..d83cb289d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -65,8 +65,6 @@ function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) return vi end -has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) - syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) diff --git a/src/utils.jl b/src/utils.jl index ba79f94b4..4a0eea96c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,9 +9,6 @@ const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. VarNameTuple = NTuple{N,VarName} where {N} -# TODO(mhauru) This is currently used in the transformation functions of NoDist, -# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in -# SimpleVarInfo and maybe other places. """ The type for all log probability variables. @@ -506,8 +503,6 @@ end # UnivariateDistributions need to be handled as a special case, because size(dist) is (), # which makes the usual machinery think we are dealing with a 0-dim array, whereas in # actuality we are dealing with a scalar. -# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and -# VarNamedVector takes over from Metadata. 
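# For instance, `size(Normal())` is `()`, yet `rand(Normal())` returns a plain
# `Float64` rather than a zero-dimensional array, so the vectorisation
# transforms below must unwrap that singleton dimension explicitly.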
function from_linked_vec_transform(dist::UnivariateDistribution)
    f_invlink = invlink_transform(dist)
    f_vec = from_vec_transform(inverse(f_invlink), size(dist))
diff --git a/src/varinfo.jl b/src/varinfo.jl
deleted file mode 100644
index d1ea7dae3..000000000
--- a/src/varinfo.jl
+++ /dev/null
@@ -1,1810 +0,0 @@
-####
-#### Types for typed and untyped VarInfo
-####
-
-####################
-# VarInfo metadata #
-####################
-
-"""
-The `Metadata` struct stores some metadata about the parameters of the model. This helps
-query certain information about a variable, such as its distribution, which samplers
-sample this variable, its value and whether this value is transformed to real space or
-not.
-
-Let `md` be an instance of `Metadata`:
-- `md.vns` is the vector of all `VarName` instances.
-- `md.idcs` is the dictionary that maps each `VarName` instance to its index in
-  `md.vns`, `md.ranges`, `md.dists`, and `md.is_transformed`.
-- `md.vns[md.idcs[vn]] == vn`.
-- `md.dists[md.idcs[vn]]` is the distribution of `vn`.
-- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`.
-- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values corresponding to `vn`.
-- `md.is_transformed` is a BitVector of true/false flags for whether a variable has been
-  transformed. `md.is_transformed[md.idcs[vn]]` is the value corresponding to `vn`.
-
-To make `md::Metadata` type stable, all the `md.vns` must have the same symbol
-and distribution type. However, one can have a Julia variable, say `x`, that is a
-matrix or a hierarchical array sampled in partitions, e.g.
-`x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`, and is managed by
-a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the
-same type. Type unstable `Metadata` will still work but will have inferior performance.
-When sampling, the first iteration uses a type unstable `Metadata` for all the
-variables, then a specialized `Metadata` is used for each symbol along with a function
-barrier to make the rest of the sampling type stable.
-"""
-struct Metadata{
-    TIdcs<:Dict{<:VarName,Int},
-    TDists<:AbstractVector{<:Distribution},
-    TVN<:AbstractVector{<:VarName},
-    TVal<:AbstractVector{<:Real},
-}
-    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
-    idcs::TIdcs # Dict{<:VarName,Int}
-
-    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
-    vns::TVN # AbstractVector{<:VarName}
-
-    # Vector of index ranges in `vals` corresponding to `vns`
-    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
-    ranges::Vector{UnitRange{Int}}
-
-    # Vector of values of all the univariate, multivariate and matrix variables
-    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
-    vals::TVal # AbstractVector{<:Real}
-
-    # Vector of distributions corresponding to `vns`
-    dists::TDists # AbstractVector{<:Distribution}
-
-    is_transformed::BitVector
-end
-
-function Base.:(==)(md1::Metadata, md2::Metadata)
-    return (
-        md1.idcs == md2.idcs &&
-        md1.vns == md2.vns &&
-        md1.ranges == md2.ranges &&
-        md1.vals == md2.vals &&
-        md1.dists == md2.dists &&
-        md1.is_transformed == md2.is_transformed
-    )
-end
-
-###########
-# VarInfo #
-###########
-
-"""
-    struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
-        metadata::Tmeta
-        accs::Accs
-    end
-
-A light wrapper over some kind of metadata.
-
-The type of the metadata can be one of a number of options.
It may either be a -`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps -symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers -to a Julia variable and may consist of one or more `VarName`s which appear on -the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both -have the same symbol `x`. - -Several type aliases are provided for these forms of VarInfos: -- `VarInfo{<:Metadata}` is `UntypedVarInfo` -- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` -- `VarInfo{<:NamedTuple}` is `NTVarInfo` - -The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability -of model evaluation. However, the element type of NamedTuples are not contained -in its type itself: thus, there is no way to use the type system to determine -whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. - -Note that for NTVarInfo, it is the user's responsibility to ensure that each -symbol is visited at least once during model evaluation, regardless of any -stochastic branching. -""" -struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo - metadata::Tmeta - accs::Accs -end -function VarInfo(meta=Metadata()) - return VarInfo(meta, default_accumulators()) -end - -""" - VarInfo( - [rng::Random.AbstractRNG], - model, - [init_strategy::AbstractInitStrategy] - ) - -Generate a `VarInfo` object for the given `model`, by initialising it with the -given `rng` and `init_strategy`. - -!!! warning - - This function currently returns a `VarInfo` with its metadata field set to - a `NamedTuple` of `Metadata`. This is an implementation detail. In general, - this function may return any kind of object that satisfies the - `AbstractVarInfo` interface. If you require precise control over the type - of `VarInfo` returned, use the internal functions `untyped_varinfo`, - `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` - instead. -""" -function VarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_varinfo(rng, model, init_strategy) -end -function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VarInfo(Random.default_rng(), model, init_strategy) -end - -const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} -const UntypedVarInfo = VarInfo{<:Metadata} -# TODO: NTVarInfo carries no information about the type of the actual metadata -# i.e. the elements of the NamedTuple. It could be Metadata or it could be -# VarNamedVector. -# Resolving this ambiguity would likely require us to replace NamedTuple with -# something which carried both its keys as well as its values' types as type -# parameters. -const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} - -function Base.:(==)(vi1::VarInfo, vi2::VarInfo) - return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) -end - -# NOTE: This is kind of weird, but it effectively preserves the "old" -# behavior where we're allowed to call `link!` on the same `VarInfo` -# multiple times. -transformation(::VarInfo) = DynamicTransformation() - -# No-op if we're already working with a `VarNamedVector`. 
-metadata_to_varnamedvector(vnv::VarNamedVector) = vnv -function metadata_to_varnamedvector(md::Metadata) - idcs = copy(md.idcs) - vns = copy(md.vns) - ranges = copy(md.ranges) - vals = copy(md.vals) - is_trans = map(Base.Fix1(is_transformed, md), md.vns) - transforms = map(md.dists, is_trans) do dist, trans - if trans - return from_linked_vec_transform(dist) - else - return from_vec_transform(dist) - end - end - - return VarNamedVector( - OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms, is_trans - ) -end - -function has_varnamedvector(vi::VarInfo) - return vi.metadata isa VarNamedVector || - (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) -end - -######################## -# VarInfo constructors # -######################## - -""" - untyped_varinfo([rng, ]model[, init_strategy]) - -Construct a VarInfo object for the given `model`, which has just a single -`Metadata` as its metadata field. - -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. -""" -function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) -end -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) -end - -""" - typed_varinfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. 
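-
-For example (an added illustration, not in the original docstring): an untyped
-`VarInfo` holding `x[1]`, `x[2]`, and `y` is regrouped as
-`(x = <Metadata for x[1] and x[2]>, y = <Metadata for y>)`.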
-""" -function typed_varinfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New is_transformed - sym_is_transformed = meta.is_transformed[inds] - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_is_transformed - ), - ) - end - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, copy(vi.accs)) -end -function typed_varinfo(vi::NTVarInfo) - # This function preserves the behaviour of typed_varinfo(vi) where vi is - # already a NTVarInfo - has_varnamedvector(vi) && error( - "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", - ) - return vi -end -""" - typed_varinfo([rng, ]model[, init_strategy]) - -Return a VarInfo object for the given `model`, which has a NamedTuple of -`Metadata` structs as its metadata field. - -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. -""" -function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) -end -function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return typed_varinfo(Random.default_rng(), model, init_strategy) -end - -""" - untyped_vector_varinfo([rng, ]model[, init_strategy]) - -Return a VarInfo object for the given `model`, which has just a single -`VarNamedVector` as its metadata field. - -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. -""" -function untyped_vector_varinfo(vi::UntypedVarInfo) - md = metadata_to_varnamedvector(vi.metadata) - return VarInfo(md, copy(vi.accs)) -end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy)) -end -function untyped_vector_varinfo( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) - return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) -end - -""" - typed_vector_varinfo([rng, ]model[, init_strategy]) - -Return a VarInfo object for the given `model`, which has a NamedTuple of -`VarNamedVector`s as its metadata field. 
- -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. -""" -function typed_vector_varinfo(vi::NTVarInfo) - md = map(metadata_to_varnamedvector, vi.metadata) - return VarInfo(md, copy(vi.accs)) -end -function typed_vector_varinfo(vi::UntypedVectorVarInfo) - new_metas = group_by_symbol(vi.metadata) - nt = NamedTuple(new_metas) - return VarInfo(nt, copy(vi.accs)) -end -function typed_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) -end -function typed_vector_varinfo( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) - return typed_vector_varinfo(Random.default_rng(), model, init_strategy) -end - -""" - vector_length(varinfo::VarInfo) - -Return the length of the vector representation of `varinfo`. -""" -vector_length(varinfo::VarInfo) = length(varinfo.metadata) -vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata) -vector_length(md::Metadata) = sum(length, md.ranges) - -function unflatten(vi::VarInfo, x::AbstractVector) - md = unflatten_metadata(vi.metadata, x) - return VarInfo(md, vi.accs) -end - -# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in -# utils.jl. -@generated function unflatten_metadata( - metadata::NamedTuple{names}, x::AbstractVector -) where {names} - exprs = [] - offset = :(0) - for f in names - mdf = :(metadata.$f) - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.is_transformed) -end - -unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) - -#### -#### Internal functions -#### - -""" - Metadata() - -Construct an empty type unstable instance of `Metadata`. -""" -function Metadata() - vals = Vector{Real}() - is_transformed = BitVector() - - return Metadata( - Dict{VarName,Int}(), - Vector{VarName}(), - Vector{UnitRange{Int}}(), - vals, - Vector{Distribution}(), - is_transformed, - ) -end - -""" - empty!(meta::Metadata) - -Empty the fields of `meta`. - -This is useful when using a sampling algorithm that assumes an empty `meta`, e.g. `SMC`. -""" -function empty!(meta::Metadata) - empty!(meta.idcs) - empty!(meta.vns) - empty!(meta.ranges) - empty!(meta.vals) - empty!(meta.dists) - empty!(meta.is_transformed) - return meta -end - -# Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined. 
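-# For example, `_tail((a=1, b=2, c=3))` yields `(b = 2, c = 3)`.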
-if VERSION < v"1.1" - _tail(nt::NamedTuple{names}) where {names} = NamedTuple{Base.tail(names)}(nt) -else - _tail(nt::NamedTuple) = Base.tail(nt) -end - -function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) - metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, map(copy, getaccs(varinfo))) -end - -function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) - vns_syms = Set(unique(map(getsym, vns))) - syms = filter(Base.Fix2(in, vns_syms), keys(metadata)) - metadatas = map(syms) do sym - subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns)) - end - return NamedTuple{syms}(metadatas) -end - -# The above method is type unstable since we don't know which symbols are in `vns`. -# In the below special case, when all `vns` have the same symbol, we can write a type stable -# version. - -@generated function subset( - metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}} -) where {names,sym} - return if (sym in names) - # TODO(mhauru) Note that this could still generate an empty metadata object if none - # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for - # emptiness would make this type unstable again. - :((; $sym=subset(metadata.$sym, vns))) - else - :(NamedTuple{}()) - end -end - -function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName} - # TODO: Should we error if `vns` contains a variable that is not in `metadata`? - # Find all the vns in metadata that are subsumed by one of the given vns. - vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns) - indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) - indices = if isempty(vns) - Dict{VarName,Int}() - else - Dict(vn => i for (i, vn) in enumerate(vns)) - end - # Construct new `vals` and `ranges`. - vals_original = metadata.vals - ranges_original = metadata.ranges - # Allocate the new `vals`. and `ranges`. - vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0)) - ranges = similar(ranges_original, length(vns)) - # The new range `r` for `vns[i]` is offset by `offset` and - # has the same length as the original range `r_original`. - # The new `indices` (from above) ensures ordering according to `vns`. - # NOTE: This means that the order of the variables in `vns` defines the order - # in the resulting `varinfo`! This can have performance implications, e.g. - # if in the model we have something like - # - # for i = 1:N - # x[i] ~ Normal() - # end - # - # and we then we do - # - # subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))]) - # - # the resulting `varinfo` will have `vals` ordered differently from the - # original `varinfo`, which can have performance implications. 
- offset = 0 - for (idx, idx_original) in enumerate(indices_for_vns) - r_original = ranges_original[idx_original] - r = (offset + 1):(offset + length(r_original)) - vals[r] = vals_original[r_original] - ranges[idx] = r - offset = r[end] - end - - dists = metadata.dists[indices_for_vns] - is_transformed = metadata.is_transformed[indices_for_vns] - return Metadata(indices, vns, ranges, vals, dists, is_transformed) -end - -function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) - return _merge(varinfo_left, varinfo_right) -end - -function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) - metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - accs = map(copy, getaccs(varinfo_right)) - return VarInfo(metadata, accs) -end - -function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) - return merge(vnv_left, vnv_right) -end - -@generated function merge_metadata( - metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} -) where {names_left,names_right} - names = Expr(:tuple) - vals = Expr(:tuple) - # Loop over `names_left` first because we want to preserve the order of the variables. - for sym in names_left - push!(names.args, QuoteNode(sym)) - if sym in names_right - push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym))) - else - push!(vals.args, :(metadata_left.$sym)) - end - end - # Loop over remaining variables in `names_right`. - names_right_only = filter(∉(names_left), names_right) - for sym in names_right_only - push!(names.args, QuoteNode(sym)) - push!(vals.args, :(metadata_right.$sym)) - end - - return :(NamedTuple{$names}($vals)) -end - -function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) - # Extract the varnames. - vns_left = metadata_left.vns - vns_right = metadata_right.vns - vns_both = union(vns_left, vns_right) - - # Determine `eltype` of `vals`. - T_left = eltype(metadata_left.vals) - T_right = eltype(metadata_right.vals) - T = promote_type(T_left, T_right) - # TODO: Is this necessary? - if !(T <: Real) - T = Real - end - - # Determine `eltype` of `dists`. - D_left = eltype(metadata_left.dists) - D_right = eltype(metadata_right.dists) - D = promote_type(D_left, D_right) - # TODO: Is this necessary? - if !(D <: Distribution) - D = Distribution - end - - # Initialize required fields for `metadata`. - vns = VarName[] - idcs = Dict{VarName,Int}() - ranges = Vector{UnitRange{Int}}() - vals = T[] - dists = D[] - transformed = BitVector() - - # Range offset. - offset = 0 - - for (idx, vn) in enumerate(vns_both) - idcs[vn] = idx - push!(vns, vn) - metadata_for_vn = vn in vns_right ? metadata_right : metadata_left - - val = getindex_internal(metadata_for_vn, vn) - append!(vals, val) - r = (offset + 1):(offset + length(val)) - push!(ranges, r) - offset = r[end] - dist = getdist(metadata_for_vn, vn) - push!(dists, dist) - push!(transformed, is_transformed(metadata_for_vn, vn)) - end - - return Metadata(idcs, vns, ranges, vals, dists, transformed) -end - -const VarView = Union{Int,UnitRange,Vector{Int}} - -""" - setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) - -Set the value of `vi.vals[vview]` to `val`. -""" -setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val - -""" - getmetadata(vi::VarInfo, vn::VarName) - -Return the metadata in `vi` that belongs to `vn`. 
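-
-For a `VarInfo` wrapping a single `Metadata` this is the metadata itself; for an
-`NTVarInfo` it is the entry of the metadata `NamedTuple` matching `getsym(vn)`.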
-""" -getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) - -""" - getidx(vi::VarInfo, vn::VarName) - -Return the index of `vn` in the metadata of `vi` corresponding to `vn`. -""" -getidx(vi::VarInfo, vn::VarName) = getidx(getmetadata(vi, vn), vn) -getidx(md::Metadata, vn::VarName) = md.idcs[vn] - -""" - getrange(vi::VarInfo, vn::VarName) - -Return the index range of `vn` in the metadata of `vi`. -""" -getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) -getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)] - -""" - setrange!(vi::VarInfo, vn::VarName, range) - -Set the index range of `vn` in the metadata of `vi` to `range`. -""" -setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range) -setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range - -""" - getdist(vi::VarInfo, vn::VarName) - -Return the distribution from which `vn` was sampled in `vi`. -""" -getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) -getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] -# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone. -function getdist(::VarNamedVector, ::VarName) - throw(ErrorException("getdist does not exist for VarNamedVector")) -end - -getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, -# since then we might be returning a `SubArray` rather than an `Array`, which is typically -# what a bijector would result in, even if the input is a view (`SubArray`). -# TODO(torfjelde): An alternative is to implement `view` directly instead. -getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) -function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) -end -getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) -# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. -# See for example https://github.com/JuliaLang/julia/pull/46381. -function getindex_internal(vi::NTVarInfo, ::Colon) - return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) -end -function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) - return float(Real)[] -end -function getindex_internal(md::Metadata, ::Colon) - return mapreduce( - Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) - ) -end - -""" - setval!(vi::VarInfo, val, vn::VarName) - -Set the value(s) of `vn` in the metadata of `vi` to `val`. - -The values may or may not be transformed to Euclidean space. 
-""" -setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) -function setval!(md::Metadata, val::AbstractVector, vn::VarName) - return md.vals[getrange(md, vn)] = val -end -function setval!(md::Metadata, val, vn::VarName) - return md.vals[getrange(md, vn)] = tovec(val) -end - -function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName) - md = set_transformed!!(getmetadata(vi, vn), val, vn) - return Accessors.@set vi.metadata[getsym(vn)] = md -end - -function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) - md = set_transformed!!(getmetadata(vi, vn), val, vn) - return VarInfo(md, vi.accs) -end - -function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) - metadata.is_transformed[getidx(metadata, vn)] = val - return metadata -end - -function set_transformed!!(vi::VarInfo, val::Bool) - for vn in keys(vi) - vi = set_transformed!!(vi, val, vn) - end - - return vi -end - -set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) -# HACK: This is necessary to make something like `link!!(transformation, vi, model)` -# work properly, which will transform the variables according to `transformation` -# and then call `set_transformed!!(vi, transformation)`. An alternative would be to add -# the `transformation` to the `VarInfo` object, but at the moment doesn't seem -# worth it as `VarInfo` has its own way of handling transformations. -set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) - -""" - syms(vi::VarInfo) - -Returns a tuple of the unique symbols of random variables in `vi`. -""" -syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::NTVarInfo) = keys(vi.metadata) - -_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) -_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) - -@generated function _getidcs(metadata::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findinds(metadata.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -@inline findinds(f_meta::Metadata) = eachindex(f_meta.vns) -findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) - -""" - all_varnames_grouped_by_symbol(vi::NTVarInfo) - -Return a `NamedTuple` of the variables in `vi` grouped by symbol. -""" -all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata) - -@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :($f = keys(md.$f))) - end - return expr -end - -#### -#### APIs for typed and untyped VarInfo -#### - -function BangBang.empty!!(vi::VarInfo) - _empty!(vi.metadata) - vi = resetaccs!!(vi) - return vi -end - -_empty!(metadata) = empty!(metadata) -@generated function _empty!(metadata::NamedTuple{names}) where {names} - expr = Expr(:block) - for f in names - push!(expr.args, :(empty!(metadata.$f))) - end - return expr -end - -# `keys` -Base.keys(md::Metadata) = md.vns -Base.keys(vi::VarInfo) = Base.keys(vi.metadata) - -# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly -# on other methods in the codebase which requires `Vector{<:VarName}`. 
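-# (Added note: without this special case, the generated method below would emit
-# a bare `vcat()` call for an empty `NamedTuple`, and `vcat()` returns `Any[]`.)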
-Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] -@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names} - expr = Expr(:call) - push!(expr.args, :vcat) - - for n in names - push!(expr.args, :(keys(vi.metadata.$n))) - end - - return expr -end - -is_transformed(vi::VarInfo, vn::VarName) = is_transformed(getmetadata(vi, vn), vn) -is_transformed(md::Metadata, vn::VarName) = md.is_transformed[getidx(md, vn)] - -getaccs(vi::VarInfo) = vi.accs -setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs - -# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). -isempty(vi::VarInfo) = _isempty(vi.metadata) -_isempty(metadata::Metadata) = isempty(metadata.idcs) -_isempty(vnv::VarNamedVector) = isempty(vnv) -@generated function _isempty(metadata::NamedTuple{names}) where {names} - return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) -end - -function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) - vns = all_varnames_grouped_by_symbol(vi) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _link(model, vi, vns) - vi = _link!!(vi, vns) - return vi -end - -function link!!(::DynamicTransformation, vi::VarInfo, model::Model) - vns = keys(vi) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _link(model, vi, vns) - vi = _link!!(vi, vns) - return vi -end - -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) -end - -function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _link(model, vi, vns) - vi = _link!!(vi, vns) - return vi -end - -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) -end - -function _link!!(vi::UntypedVarInfo, vns) - # TODO: Change to a lazy iterator over `vns` - if ~is_transformed(vi, vns[1]) - for vn in vns - f = internal_to_linked_internal_transform(vi, vn) - vi = _inner_transform!(vi, vn, f) - vi = set_transformed!!(vi, true, vn) - end - return vi - else - @warn("[DynamicPPL] attempt to link a linked vi") - end -end - -# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the NTVarInfo. -function _link!!(vi::NTVarInfo, vns::VarNameTuple) - return _link!!(vi, group_varnames_by_symbol(vns)) -end - -function _link!!(vi::NTVarInfo, vns::NamedTuple) - return _link!!(vi.metadata, vi, vns) -end - -""" - filter_subsumed(filter_vns, filtered_vns) - -Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`. 
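-
-For example (illustrative), `filter_subsumed((@varname(x),), (@varname(x[1]), @varname(y)))`
-returns `(@varname(x[1]),)`, since `x` subsumes `x[1]` but not `y`.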
-""" -function filter_subsumed(filter_vns, filtered_vns) - return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) -end - -@generated function _link!!( - ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names} -) where {metadata_names,vns_names} - expr = Expr(:block) - for f in metadata_names - if !(f in vns_names) - continue - end - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(varnames.$f, f_vns) - if !isempty(f_vns) - if !is_transformed(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = internal_to_linked_internal_transform(vi, vn) - vi = _inner_transform!(vi, vn, f) - vi = set_transformed!!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end - end - end, - ) - end - push!(expr.args, :(return vi)) - return expr -end - -function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) - vns = all_varnames_grouped_by_symbol(vi) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _invlink(model, vi, vns) - vi = _invlink!!(vi, vns) - return vi -end - -function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) - vns = keys(vi) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _invlink(model, vi, vns) - vi = _invlink!!(vi, vns) - return vi -end - -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end - -function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) - # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return _invlink(model, vi, vns) - vi = _invlink!!(vi, vns) - return vi -end - -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) -end - -function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) - # Because `VarInfo` does not contain any information about what the transformation - # other than whether or not it has actually been transformed, the best we can do - # is just assume that `default_transformation` is the correct one if - # `is_transformed(vi)`. - t = is_transformed(vi) ? default_transformation(model, vi) : NoTransformation() - return maybe_invlink_before_eval!!(t, vi, model) -end - -function _invlink!!(vi::UntypedVarInfo, vns) - if is_transformed(vi, vns[1]) - for vn in vns - f = linked_internal_to_internal_transform(vi, vn) - vi = _inner_transform!(vi, vn, f) - vi = set_transformed!!(vi, false, vn) - end - return vi - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end -end - -# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the NTVarInfo. 
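-# (Roughly, `(@varname(x[1]), @varname(y))` is regrouped under its symbols,
-# i.e. something like `(x = (x[1],), y = (y,))`.)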
-function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) -end - -function _invlink!!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!!(vi.metadata, vi, vns) -end - -@generated function _invlink!!( - ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} -) where {metadata_names,vns_names} - expr = Expr(:block) - for f in metadata_names - if !(f in vns_names) - continue - end - - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns.$f, f_vns) - if is_transformed(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - vi = _inner_transform!(vi, vn, f) - vi = set_transformed!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end, - ) - end - push!(expr.args, :(return vi)) - return expr -end - -function _inner_transform!(vi::VarInfo, vn::VarName, f) - return _inner_transform!(getmetadata(vi, vn), vi, vn, f) -end - -function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) - # TODO: Use inplace versions to avoid allocations - yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn)) - # Determine the new range. - start = first(getrange(md, vn)) - # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. - setrange!(md, vn, start:(start + length(yvec) - 1)) - # Set the new value. - setval!(md, yvec, vn) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return vi -end - -function link(::DynamicTransformation, vi::NTVarInfo, model::Model) - return _link(model, vi, all_varnames_grouped_by_symbol(vi)) -end - -function link(::DynamicTransformation, varinfo::VarInfo, model::Model) - return _link(model, varinfo, keys(varinfo)) -end - -function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) -end - -function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) - return _link(model, varinfo, vns) -end - -function link( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) -end - -function _link(model::Model, varinfo::VarInfo, vns) - varinfo = deepcopy(varinfo) - md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - new_varinfo = acclogjac!!(new_varinfo, logjac) - end - return new_varinfo -end - -# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the NTVarInfo. 
-function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple)
-    return _link(model, varinfo, group_varnames_by_symbol(vns))
-end
-
-function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
-    varinfo = deepcopy(varinfo)
-    md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
-    new_varinfo = VarInfo(md, varinfo.accs)
-    if hasacc(new_varinfo, Val(:LogJacobian))
-        new_varinfo = acclogjac!!(new_varinfo, logjac)
-    end
-    return new_varinfo
-end
-
-@generated function _link_metadata!(
-    model::Model,
-    varinfo::VarInfo,
-    metadata::NamedTuple{metadata_names},
-    vns::NamedTuple{vns_names},
-) where {metadata_names,vns_names}
-    expr = quote
-        cumulative_logjac = zero(LogProbType)
-    end
-    mds = Expr(:tuple)
-    for f in metadata_names
-        if f in vns_names
-            push!(
-                mds.args,
-                quote
-                    begin
-                        md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f)
-                        cumulative_logjac += logjac
-                        md
-                    end
-                end,
-            )
-        else
-            push!(mds.args, :(metadata.$f))
-        end
-    end
-
-    push!(
-        expr.args,
-        quote
-            NamedTuple{$metadata_names}($mds), cumulative_logjac
-        end,
-    )
-    return expr
-end
-
-function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns)
-    vns = metadata.vns
-    cumulative_logjac = zero(LogProbType)
-
-    # Construct the new transformed values, and keep track of their lengths.
-    vals_new = map(vns) do vn
-        # Return early if we're already in unconstrained space.
-        # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
-        if is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns)
-            return metadata.vals[getrange(metadata, vn)]
-        end
-
-        # Transform to unconstrained space.
-        x = getindex_internal(metadata, vn)
-        dist = getdist(metadata, vn)
-        f = internal_to_linked_internal_transform(varinfo, vn, dist)
-        y, logjac = with_logabsdet_jacobian(f, x)
-        # Vectorize value.
-        yvec = tovec(y)
-        # Accumulate the log-abs-det jacobian correction.
-        cumulative_logjac += logjac
-        # Mark as transformed.
-        set_transformed!!(varinfo, true, vn)
-        # Return the vectorized transformed value.
-        return yvec
-    end
-
-    # Determine new ranges.
-    ranges_new = similar(metadata.ranges)
-    offset = 0
-    for (i, v) in enumerate(vals_new)
-        r_start, r_end = offset + 1, length(v) + offset
-        offset = r_end
-        ranges_new[i] = r_start:r_end
-    end
-
-    # Now we just create a new metadata with the new `vals` and `ranges`.
-    return Metadata(
-        metadata.idcs,
-        metadata.vns,
-        ranges_new,
-        reduce(vcat, vals_new),
-        metadata.dists,
-        metadata.is_transformed,
-    ),
-    cumulative_logjac
-end
-
-function _link_metadata!!(
-    model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns
-)
-    vns = target_vns === nothing ? keys(metadata) : target_vns
-    dists = extract_priors(model, varinfo)
-    cumulative_logjac = zero(LogProbType)
-    for vn in vns
-        # First transform from however the variable is stored in vnv to the model
-        # representation.
-        transform_to_orig = gettransform(metadata, vn)
-        val_old = getindex_internal(metadata, vn)
-        val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old)
-        # Then transform from the model representation to the linked representation.
-        transform_from_linked = from_linked_vec_transform(dists[vn])
-        transform_to_linked = inverse(transform_from_linked)
-        val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig)
-        # TODO(mhauru) We are calling a !! function but ignoring the return value.
-        # Fix this when attending to issue #653.
- cumulative_logjac += logjac1 + logjac2 - metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) - set_transformed!(metadata, true, vn) - end - # Linking can often change the sizes of variables, causing inactive elements. We don't - # want to keep them around, since typically linking is done once and then the VarInfo - # is evaluated multiple times. Hence we contiguify here. - metadata = contiguify!(metadata) - return metadata, cumulative_logjac -end - -function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) - return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) -end - -function invlink(::DynamicTransformation, vi::VarInfo, model::Model) - return _invlink(model, vi, keys(vi)) -end - -function invlink( - ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) -end - -function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) - return _invlink(model, varinfo, vns) -end - -function invlink( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) -end - -function _invlink(model::Model, varinfo::VarInfo, vns) - varinfo = deepcopy(varinfo) - md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - new_varinfo = acclogjac!!(new_varinfo, inv_logjac) - end - return new_varinfo -end - -# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the NTVarInfo. -function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) - return _invlink(model, varinfo, group_varnames_by_symbol(vns)) -end - -function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) - varinfo = deepcopy(varinfo) - md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. 
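-        # (Added note: for a bijection `f` with `y = f(x)`, we have
-        # `logabsdetjac(inverse(f), y) == -logabsdetjac(f, x)`, so adding
-        # `inv_logjac` exactly cancels the contribution accumulated earlier.)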
- new_varinfo = acclogjac!!(new_varinfo, inv_logjac) - end - return new_varinfo -end - -@generated function _invlink_metadata!( - model::Model, - varinfo::VarInfo, - metadata::NamedTuple{metadata_names}, - vns::NamedTuple{vns_names}, -) where {metadata_names,vns_names} - expr = quote - cumulative_inv_logjac = zero(LogProbType) - end - mds = Expr(:tuple) - for f in metadata_names - if (f in vns_names) - push!( - mds.args, - quote - begin - md, inv_logjac = _invlink_metadata!!( - model, varinfo, metadata.$f, vns.$f - ) - cumulative_inv_logjac += inv_logjac - md - end - end, - ) - else - push!(mds.args, :(metadata.$f)) - end - end - - push!( - expr.args, - quote - (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) - end, - ) - return expr -end - -function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) - vns = metadata.vns - cumulative_inv_logjac = zero(LogProbType) - - # Construct the new transformed values, and keep track of their lengths. - vals_new = map(vns) do vn - # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`. - # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if !is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) - return metadata.vals[getrange(metadata, vn)] - end - - # Transform to constrained space. - y = getindex_internal(varinfo, vn) - dist = getdist(varinfo, vn) - f = from_linked_internal_transform(varinfo, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - # Vectorize value. - xvec = tovec(x) - # Accumulate the log-abs-det jacobian correction. - cumulative_inv_logjac += inv_logjac - # Mark as no longer transformed. - set_transformed!!(varinfo, false, vn) - # Return the vectorized transformed value. - return xvec - end - - # Determine new ranges. - ranges_new = similar(metadata.ranges) - offset = 0 - for (i, v) in enumerate(vals_new) - r_start, r_end = offset + 1, length(v) + offset - offset = r_end - ranges_new[i] = r_start:r_end - end - - # Now we just create a new metadata with the new `vals` and `ranges`. - return Metadata( - metadata.idcs, - metadata.vns, - ranges_new, - reduce(vcat, vals_new), - metadata.dists, - metadata.is_transformed, - ), - cumulative_inv_logjac -end - -function _invlink_metadata!!( - ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns -) - vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_inv_logjac = zero(LogProbType) - for vn in vns - transform = gettransform(metadata, vn) - old_val = getindex_internal(metadata, vn) - new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) - # TODO(mhauru) We are calling a !! function but ignoring the return value. - cumulative_inv_logjac += inv_logjac - new_transform = from_vec_transform(new_val) - metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) - set_transformed!(metadata, false, vn) - end - # Linking can often change the sizes of variables, causing inactive elements. We don't - # want to keep them around, since typically linking is done once and then the VarInfo - # is evaluated multiple times. Hence we contiguify here. - metadata = contiguify!(metadata) - return metadata, cumulative_inv_logjac -end - -# TODO(mhauru) The treatment of the case when some variables are transformed and others are -# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` -# returned whether the first variable was linked. 
For NTVarInfo we did an OR over the first -# variables under each symbol. We now more consistently use OR, but I'm not convinced this -# is really the right thing to do. -""" - is_transformed(vi::VarInfo) - -Check whether `vi` is in the transformed space. - -Turing's Hamiltonian samplers use the `link` and `invlink` functions from -[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `is_transformed` checks if the number is in the constrained space or the real -space. - -If some but only some of the variables in `vi` are transformed, this function will return -`true`. This behavior will likely change in the future. -""" -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) -end - -# The default getindex & setindex!() for get & set values -# NOTE: vi[vn] will always transform the variable to its original space and Julia type -function getindex(vi::VarInfo, vn::VarName) - return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) -end - -function getindex(vi::VarInfo, vn::VarName, dist::Distribution) - @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_internal(vi, vn) - return from_maybe_linked_internal(vi, vn, dist, val) -end - -function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals = map(vn -> getindex(vi, vn), vns) - - et = eltype(vals) - # This will catch type unstable cases, where vals has mixed types. - if !isconcretetype(et) - throw(ArgumentError("All variables must have the same type.")) - end - - if et <: Vector - all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) - if !all_of_equal_dimension - throw(ArgumentError("All variables must have the same dimension.")) - end - end - - # TODO(mhauru) I'm not very pleased with the return type varying like this, even though - # this should be type stable. - vec_vals = reduce(vcat, vals) - if et <: Vector - # The individual variables are multivariate, and thus we return the values as a - # matrix. - return reshape(vec_vals, (:, length(vns))) - else - # The individual variables are univariate, and thus we return a vector of scalars. - return vec_vals - end -end - -function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) - @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -# Recursively builds a tuple of the `vals` of all the symbols -@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) - end - return expr -end - -# TODO(mhauru) I think the below implementation of setindex! is a mistake. It should be -# called setindex_internal! since it directly writes to the `vals` field of the metadata. -""" - setindex!(vi::VarInfo, val, vn::VarName) - -Set the current value(s) of the random variable `vn` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. 
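-
-!!! warning
-    As the TODO above notes, this writes directly to the internal `vals` field;
-    it does not transform `val` or update any transform metadata.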
-""" -setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - setindex!(vi, val, vn) - return vi -end - -@inline function findvns(vi, f_vns) - if length(f_vns) == 0 - throw("Unidentified error, please report this error in an issue.") - end - return map(vn -> vi[vn], f_vns) -end - -Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) - -""" - haskey(vi::VarInfo, vn::VarName) - -Check whether `vn` has a value in `vi`. -""" -Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::NTVarInfo, vn::VarName) - md_haskey = map(vi.metadata) do metadata - haskey(metadata, vn) - end - return any(md_haskey) -end - -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - lines = Tuple{String,Any}[ - ("VarNames", vi.metadata.vns), - ("Range", vi.metadata.ranges), - ("Vals", vi.metadata.vals), - ] - for accname in acckeys(vi) - push!(lines, (string(accname), getacc(vi, Val(accname)))) - end - push!(lines, ("is_transformed", vi.metadata.is_transformed)) - max_name_length = maximum(map(length ∘ first, lines)) - fmt = Printf.Format("%-$(max_name_length)s") - vi_str = ( - """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - """ * - prod( - map(lines) do (name, value) - """ - | $(Printf.format(fmt, name)) : $(value) - """ - end, - ) * - """ - \\======================================================================= - """ - ) - return print(io, vi_str) -end - -const _MAX_VARS_SHOWN = 4 - -function _show_varnames(io::IO, vi) - md = vi.metadata - vns = keys(md) - - vns_by_name = Dict{Symbol,Vector{VarName}}() - for vn in vns - group = get!(() -> Vector{VarName}(), vns_by_name, getsym(vn)) - push!(group, vn) - end - - L = length(vns_by_name) - if L == 0 - print(io, "0 variables, dimension 0") - else - (L == 1) ? print(io, "1 variable (") : print(io, L, " variables (") - join(io, Iterators.take(keys(vns_by_name), _MAX_VARS_SHOWN), ", ") - (L > _MAX_VARS_SHOWN) && print(io, ", ...") - print(io, "), dimension ", length(md.vals)) - end -end - -function Base.show(io::IO, vi::UntypedVarInfo) - print(io, "VarInfo (") - _show_varnames(io, vi) - print(io, "; accumulators: ") - # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation - # of vi anyway. However, technically `show(io, x)` should give full details of x and - # preferably output valid Julia code. - show(io, MIME"text/plain"(), getaccs(vi)) - return print(io, ")") -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - md = push!!(getmetadata(vi, vn), vn, val, dist) - return VarInfo(md, vi.accs) -end - -function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - sym = getsym(vn) - meta = if ~haskey(vi.metadata, sym) - # The NamedTuple doesn't have an entry for this variable, let's add one. 
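-        # (`_new_submetadata`, defined below, matches the storage type of the
-        # existing entries, i.e. `VarNamedVector` versus `Metadata`.)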
- _new_submetadata(vi, vn, val, dist) - else - push!!(getmetadata(vi, vn), vn, val, dist) - end - vi = Accessors.@set vi.metadata[sym] = meta - return vi -end - -""" - _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} - -Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing -SubMetas. -""" -@generated function _new_submetadata( - vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist -) where {Names,SubMetas} - has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) - return if has_vnv - :(return _new_vnv_submetadata(vn, r, dist)) - else - :(return _new_metadata_submetadata(vn, r, dist)) - end -end - -_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) - -function _new_metadata_submetadata(vn, r, dist) - val = tovec(r) - return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) -end - -function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) - vn, val = pair - return push!(vi, vn, val, args...) -end - -# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't -# exist in the NTVarInfo already. We could implement it in the cases where it it does -# exist, but that feels a bit pointless. I think we should rather rely on `push!!`. - -function Base.push!(meta::Metadata, vn, r, dist) - val = tovec(r) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals) - n = length(val) - push!(meta.ranges, (l + 1):(l + n)) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.is_transformed, false) - return meta -end - -function BangBang.push!!(meta::Metadata, vn, r, dist) - push!(meta, vn, r, dist) - return meta -end - -function Base.delete!(vi::VarInfo, vn::VarName) - delete!(getmetadata(vi, vn), vn) - return vi -end - -####################################### -# Rand & replaying method for VarInfo # -####################################### - -# TODO: Maybe rename or something? -""" - _apply!(kernel!, vi::VarInfo, values, keys) - -Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. -""" -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) - keys_strings = map(string, collect_maybe(keys)) - num_indices_seen = 0 - - for vn in Base.keys(vi) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. - # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi -end - -function _apply!(kernel!, vi::NTVarInfo, values, keys) - return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) -end - -@generated function _typed_apply!( - kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys -) where {names} - updates = map(names) do n - quote - for vn in Base.keys(metadata.$n) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - end - end - - return quote - keys_strings = map(string, keys) - num_indices_seen = 0 - - $(updates...) - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. 
- # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi - end -end - -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) - string_vns = map(string, collect_maybe(Base.keys(vi))) - # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. - missing_keys = filter(keys) do key - !any(Base.Fix2(subsumes_string, key), string_vns) - end - - return missing_keys -end - -values_as(vi::VarInfo) = vi.metadata -values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) -function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) - iter = values_from_metadata(vi.metadata) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end -function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) -end - -function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end - -function values_as( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {names,D<:AbstractDict} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return ConstructionBase.constructorof(D)(iter) -end - -values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) - -function values_from_metadata(md::Metadata) - return ( - # `copy` to avoid accidentally mutation of internal representation. - vn => copy( - from_internal_transform(md, vn, getdist(md, vn))(getindex_internal(md, vn)) - ) for vn in md.vns - ) -end - -values_from_metadata(md::VarNamedVector) = pairs(md) - -# Transforming from internal representation to distribution representation. -# Without `dist` argument: base on `dist` extracted from self. -function from_internal_transform(vi::VarInfo, vn::VarName) - return from_internal_transform(getmetadata(vi, vn), vn) -end -function from_internal_transform(md::Metadata, vn::VarName) - return from_internal_transform(md, vn, getdist(md, vn)) -end -function from_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_internal_transform(vi::VarInfo, vn::VarName, dist) - return from_internal_transform(getmetadata(vi, vn), vn, dist) -end -from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) -function from_internal_transform(::VarNamedVector, ::VarName, dist) - return from_vec_transform(dist) -end - -# Without `dist` argument: base on `dist` extracted from self. -function from_linked_internal_transform(vi::VarInfo, vn::VarName) - return from_linked_internal_transform(getmetadata(vi, vn), vn) -end -function from_linked_internal_transform(md::Metadata, vn::VarName) - return from_linked_internal_transform(md, vn, getdist(md, vn)) -end -function from_linked_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) - # Dispatch to metadata in case this alters the behavior. 
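# Illustrative sketch (not part of the patch) of how the two transform
# families differ: only the "linked" variants undo linking. For a univariate
# distribution with positive support one would roughly expect:
#
#   using DynamicPPL, Distributions
#   d = LogNormal()
#   DynamicPPL.from_vec_transform(d)([2.0])         # 2.0 -- just un-vectorise
#   DynamicPPL.from_linked_vec_transform(d)([2.0])  # exp(2.0) -- also undo the log-link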
-    return from_linked_internal_transform(getmetadata(vi, vn), vn, dist)
-end
-function from_linked_internal_transform(::Metadata, ::VarName, dist)
-    return from_linked_vec_transform(dist)
-end
-function from_linked_internal_transform(::VarNamedVector, ::VarName, dist)
-    return from_linked_vec_transform(dist)
-end
diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl
deleted file mode 100644
index e5d2f2c2e..000000000
--- a/src/varnamedvector.jl
+++ /dev/null
@@ -1,1674 +0,0 @@
-const CHECK_CONSISTENCY_DEFAULT = true
-
-"""
-    VarNamedVector
-
-A container that stores values in a vectorised form, but indexable by variable names.
-
-A `VarNamedVector` can be thought of as an ordered mapping from `VarName`s to pairs of
-`(internal_value, transform)`. Here `internal_value` is a vectorised value for the variable
-and `transform` is a function such that `transform(internal_value)` is the "original" value
-of the variable, the one that the user sees. For instance, if the variable has a matrix
-value, `internal_value` could be a flattened `Vector` of its elements, and `transform` would
-be a `reshape` call.
-
-`transform` may implement simple vectorisation, but it may do more. Most importantly, it may
-implement linking, where the internal storage of a random variable is in a form where all
-values in Euclidean space are valid. This is useful for sampling, because the sampler can
-make changes to `internal_value` without worrying about constraints on the space of
-the random variable.
-
-The way to access this storage format directly is through the functions `getindex_internal`
-and `setindex_internal`. The `transform` argument for `setindex_internal` is optional; by
-default it is either the identity, or the existing transform if a value already exists for
-this `VarName`.
-
-`VarNamedVector` also provides a `Dict`-like interface that hides away the internal
-vectorisation. This can be accessed with `getindex` and `setindex!`. `setindex!` only takes
-the value; the transform is automatically set to be a simple vectorisation. The only notable
-deviation from the behavior of a `Dict` is that `setindex!` will throw an error if one tries
-to set a new value for a variable that lives in a different "space" than the old one (e.g.
-is of a different type or size). This is because `setindex!` does not change the transform
-of a variable (e.g. it preserves linking), and thus the new value must be compatible with
-the old transform.
-
-For now, a third value is in fact stored for each `VarName`: a boolean indicating whether
-the variable has been transformed to unconstrained Euclidean space or not. This is only in
-place temporarily due to the needs of our old Gibbs sampler.
-
-Internally, `VarNamedVector` stores the values of all variables in a single contiguous
-vector. This makes some operations more efficient, and means that one can access the entire
-contents of the internal storage quickly with `getindex_internal(vnv, :)`. The other fields
-of `VarNamedVector` are mostly used to keep track of which part of the internal storage
-belongs to which `VarName`.
-
-All constructors accept a keyword argument `check_consistency::Bool=true` that controls
-whether to run checks like the number of values matching the number of variables. Some of
-these checks can be expensive, so if you are confident in the input, you may want to turn
-`check_consistency` off for performance.
-
-# Fields
-
-$(FIELDS)
-
-# Extended help
-
-The values for different variables are internally all stored in a single vector. For
-instance,
-```jldoctest varnamedvector-struct
-julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal
-
-julia> vnv = VarNamedVector();
-
-julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x));
-
-julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y));
-
-julia> vnv.vals
-10-element Vector{Real}:
- 0.0
- 0.0
- 0.0
- 0.0
- 1
- 2
- 3
- 4
- 5
- 6
-```
-
-The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to
-which variable. The `transforms` field stores the transformations that are needed to transform
-the vectorised internal storage back to its original form:
-
-```jldoctest varnamedvector-struct
-julia> vnv.transforms[vnv.varname_to_index[@varname(y)]] == DynamicPPL.ReshapeTransform((6,), (2,3))
-true
-```
-
-If a variable is updated with a new value that is of a smaller dimension than the old
-value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive.
-
-```jldoctest varnamedvector-struct
-julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x));
-
-julia> vnv.vals
-10-element Vector{Real}:
- 46.0
- 48.0
- 0.0
- 0.0
- 1
- 2
- 3
- 4
- 5
- 6
-
-julia> println(vnv.num_inactive);
-Dict(1 => 2)
-```
-
-This helps avoid unnecessary memory allocations for values that repeatedly change dimension.
-The user does not have to worry about the inactive entries as long as they use functions
-like `setindex!` and `getindex` rather than directly accessing `vnv.vals`.
-
-```jldoctest varnamedvector-struct
-julia> vnv[@varname(x)]
-2-element Vector{Real}:
- 46.0
- 48.0
-
-julia> getindex_internal(vnv, :)
-8-element Vector{Real}:
- 46.0
- 48.0
- 1
- 2
- 3
- 4
- 5
- 6
-```
-"""
-struct VarNamedVector{
-    K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}
-}
-    """
-    mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms`
-    """
-    varname_to_index::Dict{K,Int}
-
-    """
-    vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn`
-    """
-    varnames::KVec
-
-    """
-    vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has
-    a single index or a set of contiguous indices, such that the values of `vn` can be found
-    at `vals[ranges[varname_to_index[vn]]]`
-    """
-    ranges::Vector{UnitRange{Int}}
-
-    """
-    vector of values of all variables; the value(s) of `vn` is/are
-    `vals[ranges[varname_to_index[vn]]]`
-    """
-    vals::VVec
-
-    """
-    vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable
-    that transforms the value of `vn` back to its original space, undoing any linking and
-    vectorisation
-    """
-    transforms::TVec
-
-    """
-    vector of booleans indicating whether a variable has been explicitly transformed to
-    unconstrained Euclidean space, i.e. whether its domain is all of `ℝ^ⁿ`. If
-    `is_unconstrained[varname_to_index[vn]]` is true, it guarantees that the variable
-    `vn` is not constrained. However, the converse does not hold: if `is_unconstrained`
-    is false, the variable `vn` may still happen to be unconstrained, e.g. if its
-    original distribution is itself unconstrained (like a normal distribution).
-    """
-    is_unconstrained::BitVector
-
-    """
-    mapping from a variable index to the number of inactive entries for that variable.
-    Inactive entries are elements in `vals` that are not part of the value of any variable.
-    They arise when a variable is set to a new value with a different dimension, in-place.
- Inactive entries always come after the last active entry for the given variable. - See the extended help with `??VarNamedVector` for more details. - """ - num_inactive::Dict{Int,Int} - - function VarNamedVector( - varname_to_index, - varnames::KVec, - ranges, - vals::VVec, - transforms::TVec, - is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=Dict{Int,Int}(); - check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, - ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} - if check_consistency - if length(varnames) != length(ranges) || - length(varnames) != length(transforms) || - length(varnames) != length(is_unconstrained) || - length(varnames) != length(varname_to_index) - msg = ( - "Inputs to VarNamedVector have inconsistent lengths. " * - "Got lengths varnames: $(length(varnames)), " * - "ranges: $(length(ranges)), " * - "transforms: $(length(transforms)), " * - "is_unconstrained: $(length(is_unconstrained)), " * - "varname_to_index: $(length(varname_to_index))." - ) - throw(ArgumentError(msg)) - end - - num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) - if num_vals != length(vals) - msg = ( - "The total number of elements in `vals` ($(length(vals))) does not " * - "match the sum of the lengths of the ranges and the number of " * - "inactive entries ($(num_vals))." - ) - throw(ArgumentError(msg)) - end - - if Set(values(varname_to_index)) != Set(axes(varnames, 1)) - msg = ( - "The set of values of `varname_to_index` is not the set of valid " * - "indices for `varnames`." - ) - throw(ArgumentError(msg)) - end - - if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) - msg = ( - "The keys of `num_inactive` are not a subset of the values of " * - "`varname_to_index`." - ) - throw(ArgumentError(msg)) - end - - # Check that the varnames don't overlap. The time cost is quadratic in number of - # variables. If this ever becomes an issue, we should be able to go down to at - # least N log N by sorting based on subsumes-order. - for vn1 in keys(varname_to_index) - for vn2 in keys(varname_to_index) - vn1 === vn2 && continue - if subsumes(vn1, vn2) - msg = ( - "Variables in a VarNamedVector should not subsume each " * - "other, but $vn1 subsumes $vn2." - ) - throw(ArgumentError(msg)) - end - end - end - - # We could also have a test to check that the ranges don't overlap, but that - # sounds unlikely to occur, and implementing it in linear time would require a - # tiny bit of thought. 
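# Illustrative sketch (not part of the patch; hypothetical session) of the
# checks above failing fast: one varname but empty ranges/vals/transforms is
# rejected at construction time.
#
#   using DynamicPPL: VarNamedVector, @varname
#   VarNamedVector(Dict(@varname(x) => 1), [@varname(x)], UnitRange{Int}[], Float64[], Any[])
#   # ERROR: ArgumentError: Inputs to VarNamedVector have inconsistent lengths. ...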
- end - - return new{K,V,T,KVec,VVec,TVec}( - varname_to_index, - varnames, - ranges, - vals, - transforms, - is_unconstrained, - num_inactive, - ) - end -end - -function VarNamedVector{K,V,T}() where {K,V,T} - return VarNamedVector( - Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false - ) -end - -VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() -function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) -end -function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(keys(x), values(x); check_consistency=check_consistency) -end -function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector( - collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency - ) -end -function VarNamedVector( - varnames::AbstractVector, - orig_vals::AbstractVector, - transforms=fill(identity, length(varnames)); - check_consistency=CHECK_CONSISTENCY_DEFAULT, -) - if isempty(varnames) && isempty(orig_vals) && isempty(transforms) - return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() - end - # Convert `vals` into a vector of vectors. - vals_vecs = map(tovec, orig_vals) - transforms = map( - (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals - ) - # Make `varnames` have as concrete an element type as possible. - varnames = [v for v in varnames] - varname_to_index = Dict{eltype(varnames),Int}( - vn => i for (i, vn) in enumerate(varnames) - ) - vals = reduce(vcat, vals_vecs) - # Make the ranges. - ranges = Vector{UnitRange{Int}}() - offset = 0 - for x in vals_vecs - r = (offset + 1):(offset + length(x)) - push!(ranges, r) - offset = r[end] - end - - # Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a - # lightweight check of the arguments of this function, and rely on the correctness - # of what this function does? However, the expensive check is whether any variable - # subsumes another, and that's the same regardless of where it's done, so the - # optimisation would be quite pointless. 
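# Illustrative sketch (not part of the patch) of the convenience constructors
# this method backs; `getindex_internal` is internal API:
#
#   using DynamicPPL: VarNamedVector, @varname, getindex_internal
#   vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => 3.0)
#   vnv[@varname(y)]             # 3.0 -- scalars are vectorised internally
#   getindex_internal(vnv, :)    # [1.0, 2.0, 3.0]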
- return VarNamedVector( - varname_to_index, - varnames, - ranges, - vals, - transforms; - check_consistency=check_consistency, - ) -end - -function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) - return vnv_left.varname_to_index == vnv_right.varname_to_index && - vnv_left.varnames == vnv_right.varnames && - vnv_left.ranges == vnv_right.ranges && - vnv_left.vals == vnv_right.vals && - vnv_left.transforms == vnv_right.transforms && - vnv_left.is_unconstrained == vnv_right.is_unconstrained && - vnv_left.num_inactive == vnv_right.num_inactive -end - -function is_tightly_typed(vnv::VarNamedVector) - k = eltype(vnv.varnames) - v = eltype(vnv.vals) - t = eltype(vnv.transforms) - return (isconcretetype(k) || k === Union{}) && - (isconcretetype(v) || v === Union{}) && - (isconcretetype(t) || t === Union{}) -end - -getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] - -getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] -getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) - -gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] -gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) - -# TODO(mhauru) Eventually I would like to rename the is_transformed function to -# is_unconstrained, but that's significantly breaking. -""" - is_transformed(vnv::VarNamedVector, vn::VarName) - -Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain -is all of Euclidean space. -""" -is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] - -""" - set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - -Set the value for whether `vn` is guaranteed to have been transformed so that all of -Euclidean space is its domain. -""" -function set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val -end - -function set_transformed!!(vnv::VarNamedVector, val::Bool, vn::VarName) - set_transformed!(vnv, val, vn) - return vnv -end - -""" - has_inactive(vnv::VarNamedVector) - -Returns `true` if `vnv` has inactive entries. - -See also: [`num_inactive`](@ref) -""" -has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive) - -""" - num_inactive(vnv::VarNamedVector) - -Return the number of inactive entries in `vnv`. - -See also: [`has_inactive`](@ref), [`num_allocated`](@ref) -""" -num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive)) - -""" - num_inactive(vnv::VarNamedVector, vn::VarName) - -Returns the number of inactive entries for `vn` in `vnv`. -""" -num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) -num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0) - -""" - num_allocated(vnv::VarNamedVector) - num_allocated(vnv::VarNamedVector[, vn::VarName]) - num_allocated(vnv::VarNamedVector[, idx::Int]) - -Return the number of allocated entries in `vnv`, both active and inactive. - -If either a `VarName` or an `Int` index is specified, only count entries allocated for that -variable. - -Allocated entries take up memory in `vnv.vals`, but, if inactive, may not currently hold any -meaningful data. One can remove them with [`contiguify!`](@ref), but doing so may cause more -memory allocations in the future if variables change dimension. 
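A sketch of the bookkeeping (illustrative, not a doctest), assuming the
shrinking behaviour of `update!` documented elsewhere in this file:

```julia
using DynamicPPL: VarNamedVector, @varname, update!, num_inactive, num_allocated

vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [4.0])
update!(vnv, [9.0], @varname(x))  # x shrinks from 3 elements to 1
num_inactive(vnv)                 # 2 -- two of x's slots are now inactive
num_allocated(vnv, @varname(x))   # 3 -- active + inactive slots for x
num_allocated(vnv)                # 4 -- total allocation is unchanged
```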
-""" -num_allocated(vnv::VarNamedVector) = length(vnv.vals) -num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) -function num_allocated(vnv::VarNamedVector, idx::Int) - return length(getrange(vnv, idx)) + num_inactive(vnv, idx) -end - -# Dictionary interface. -Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames) -Base.length(vnv::VarNamedVector) = length(vnv.varnames) -Base.keys(vnv::VarNamedVector) = vnv.varnames -Base.values(vnv::VarNamedVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) -Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) -Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) - -# Vector-like interface. -Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals) - -""" - length_internal(vnv::VarNamedVector) - -Return the length of the internal storage vector of `vnv`, ignoring inactive entries. -""" -function length_internal(vnv::VarNamedVector) - if !has_inactive(vnv) - return length(vnv.vals) - else - return sum(length, vnv.ranges) - end -end - -# Getting and setting values - -function Base.getindex(vnv::VarNamedVector, vn::VarName) - x = getindex_internal(vnv, vn) - f = gettransform(vnv, vn) - return f(x) -end - -""" - find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - -Find the first range in `ranges` that contains `x`. - -Throw an `ArgumentError` if `x` is not in any of the ranges. -""" -function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` - # for a more efficient approach. - range_idx = findfirst(Base.Fix1(∈, x), ranges) - - # If we're out of bounds, we raise an error. - if range_idx === nothing - throw(ArgumentError("Value $x is not in any of the ranges.")) - end - - return range_idx -end - -""" - adjusted_ranges(vnv::VarNamedVector) - -Return what `vnv.ranges` would be if there were no inactive entries. -""" -function adjusted_ranges(vnv::VarNamedVector) - # Every range following inactive entries needs to be shifted. - offset = 0 - ranges_adj = similar(vnv.ranges) - for (idx, r) in enumerate(vnv.ranges) - # Remove the `offset` in `r` due to inactive entries. - ranges_adj[idx] = r .- offset - # Update `offset`. - offset += get(vnv.num_inactive, idx, 0) - end - - return ranges_adj -end - -""" - index_to_vals_index(vnv::VarNamedVector, i::Int) - -Convert an integer index that ignores inactive entries to an index that accounts for them. - -This is needed when the user wants to index `vnv` like a vector, but shouldn't have to care -about inactive entries in `vnv.vals`. -""" -function index_to_vals_index(vnv::VarNamedVector, i::Int) - # If we don't have any inactive entries, there's nothing to do. - has_inactive(vnv) || return i - - # Get the adjusted ranges. - ranges_adj = adjusted_ranges(vnv) - # Determine the adjusted range that the index corresponds to. - r_idx = find_containing_range(ranges_adj, i) - r = vnv.ranges[r_idx] - # Determine how much of the index `i` is used to get to this range. - i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) - # Use remainder to index into `r`. - i_remainder = i - i_used - return r[i_remainder] -end - -""" - getindex_internal(vnv::VarNamedVector, vn::VarName) - -Like `getindex`, but returns the values as they are stored in `vnv`, without transforming. 
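For instance (illustrative sketch, not a doctest), a matrix-valued variable
comes back reshaped from `getindex` but flat from `getindex_internal`:

```julia
using DynamicPPL: VarNamedVector, @varname, getindex_internal

vnv = VarNamedVector(@varname(y) => [1.0 3.0; 2.0 4.0])
vnv[@varname(y)]                     # 2×2 Matrix -- the user-facing value
getindex_internal(vnv, @varname(y))  # [1.0, 2.0, 3.0, 4.0] -- raw storage
```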
-""" -getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] - -""" - getindex_internal(vnv::VarNamedVector, i::Int) - -Gets the `i`th element of the internal storage vector, ignoring inactive entries. -""" -getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] - -function getindex_internal(vnv::VarNamedVector, ::Colon) - return if has_inactive(vnv) - mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) - else - vnv.vals - end -end - -function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!(vnv, val, vn) - else - return insert!(vnv, val, vn) - end -end - -""" - reset!(vnv::VarNamedVector, val, vn::VarName) - -Reset the value of `vn` in `vnv` to `val`. - -This differs from `setindex!` in that it will always change the transform of the variable -to be the default vectorisation transform. This undoes any possible linking. - -# Examples - -```jldoctest varnamedvector-reset -julia> using DynamicPPL: VarNamedVector, @varname, reset! - -julia> vnv = VarNamedVector{VarName,Any,Any}(); - -julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); - -julia> setindex!(vnv, 2.0, @varname(x)) -ERROR: An error occurred while assigning the value 2.0 to variable x. If you are changing the type or size of a variable you'll need to call reset! -[...] - -julia> reset!(vnv, 2.0, @varname(x)); - -julia> vnv[@varname(x)] -2.0 -``` -""" -function reset!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - retval = setindex_internal!(vnv, tovec(val), vn, f) - set_transformed!(vnv, false, vn) - return retval -end - -""" - update!(vnv::VarNamedVector, val, vn::VarName) - -Update the value of `vn` in `vnv` to `val`. - -Like `setindex!`, but errors if the key `vn` doesn't exist. -""" -function update!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. " * - "If you are changing the type or size of a variable you'll need to call " * - "reset!", - ) - end - return setindex_internal!(vnv, internal_val, vn) -end - -""" - insert!(vnv::VarNamedVector, val, vn::VarName) - -Add a variable with given value to `vnv`. - -Like `setindex!`, but errors if the key `vn` already exists. -""" -function Base.insert!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!(vnv, val, vn) -end - -""" - push!(vnv::VarNamedVector, pair::Pair) - -Add a variable with given value to `vnv`. Pair should be a `VarName` and a value. -""" -function Base.push!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - # TODO(mhauru) Or should this rather call `reset!`? It would be more inline with what - # Dict does, but could also cause confusion. - return setindex!(vnv, val, vn) -end - -""" - setindex_internal!(vnv::VarNamedVector, val, i::Int) - -Sets the `i`th element of the internal storage vector, ignoring inactive entries. -""" -function setindex_internal!(vnv::VarNamedVector, val, i::Int) - return vnv.vals[index_to_vals_index(vnv, i)] = val -end - -""" - setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform]) - -Like `setindex!`, but sets the values as they are stored internally in `vnv`. - -Optionally can set the transformation, such that `transform(val)` is the original value of -the variable. 
By default, the transform is the identity if creating a new entry in `vnv`, or -the existing transform if updating an existing entry. -""" -function setindex_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!(vnv, val, vn, transform) - else - return insert_internal!(vnv, val, vn, transform) - end -end - -""" - insert_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform]) - -Add a variable with given value to `vnv`. - -Like `setindex_internal!`, but errors if the key `vn` already exists. - -`transform` should be a function that converts `val` to the original representation. By -default it's `identity`. -""" -function insert_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) - # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying - # storage. - r_new = nextrange(vnv, val) - vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 - push!(vnv.varnames, vn) - push!(vnv.ranges, r_new) - append!(vnv.vals, val) - push!(vnv.transforms, transform) - push!(vnv.is_unconstrained, false) - return nothing -end - -""" - update_internal!(vnv::VarNamedVector, vn::VarName, val::AbstractVector[, transform]) - -Update an existing entry for `vn` in `vnv` with the value `val`. - -Like `setindex_internal!`, but errors if the key `vn` doesn't exist. - -`transform` should be a function that converts `val` to the original representation. By -default it's the same as the old transform for `vn`. -""" -function update_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - # Here we update an existing entry. - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - idx = getidx(vnv, vn) - # Extract the old range. - r_old = getrange(vnv, idx) - start_old, end_old = first(r_old), last(r_old) - n_old = length(r_old) - # Compute the new range. - n_new = length(val) - start_new = start_old - end_new = start_old + n_new - 1 - r_new = start_new:end_new - - #= - Suppose we currently have the following: - - | x | x | o | o | o | y | y | y | <- Current entries - - where 'O' denotes an inactive entry, and we're going to - update the variable `x` to be of size `k` instead of 2. - - We then have a few different scenarios: - 1. `k > 5`: All inactive entries become active + need to shift `y` to the right. - E.g. if `k = 7`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | x | x | y | y | y | <- New entries - - 2. `k = 5`: All inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | y | y | y | <- New entries - - 3. `k < 5`: Some inactive entries become active, some remain inactive. - E.g. if `k = 3`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | o | o | y | y | y | <- New entries - - 4. `k = 2`: No inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | o | o | o | y | y | y | <- New entries - - 5. `k < 2`: More entries become inactive. - E.g. if `k = 1`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | o | o | o | o | y | y | y | <- New entries - =# - - # Compute the allocated space for `vn`. - had_inactive = haskey(vnv.num_inactive, idx) - n_allocated = had_inactive ? 
n_old + vnv.num_inactive[idx] : n_old - - if n_new > n_allocated - # Then we need to grow the underlying vector. - n_extra = n_new - n_allocated - # Allocate. - resize!(vnv.vals, length(vnv.vals) + n_extra) - # Shift current values. - shift_right!(vnv.vals, end_old + 1, n_extra) - # No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - # Update the ranges for all variables after this one. - shift_subsequent_ranges_by!(vnv, idx, n_extra) - elseif n_new == n_allocated - # => No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - else - # `n_new < n_allocated` - # => Need to update the number of inactive entries. - vnv.num_inactive[idx] = n_allocated - n_new - end - - # Update the range for this variable. - vnv.ranges[idx] = r_new - # Update the value. - vnv.vals[r_new] = val - if transform !== nothing - # Update the transform. - vnv.transforms[idx] = transform - end - - # TODO: Should we maybe sweep over inactive ranges and re-contiguify - # if the total number of inactive elements is "large" in some sense? - - return nothing -end - -function Base.push!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!(vnv, tovec(val), vn, f) -end - -function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!!(vnv, tovec(val), vn, f) -end - -# BangBang versions of the above functions. -# The only difference is that update_internal!! and insert_internal!! check whether the -# container types of the VarNamedVector vector need to be expanded to accommodate the new -# values. If so, they create a new instance, otherwise they mutate in place. All the others -# functions, e.g. setindex!!, setindex_internal!!, etc., are carbon copies of the ! versions -# with every ! call replaced with a !! call. - -""" - loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) - -Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. - -If `KNew` is a subtype of `K` and `TransNew` is a subtype of the element type of the -`TTrans` then this is a no-op and `vnv` is returned as is. Otherwise a new `VarNamedVector` -is returned with the same data but more abstract types, so that variables of type `KNew` and -transformations of type `TransNew` can be pushed to it. Some of the underlying storage is -shared between `vnv` and the return value, and thus mutating one may affect the other. - -# See also -[`tighten_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-loosen-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! - -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> y_trans(x) = reshape(x, (2, 2)); - -julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) -ERROR: MethodError: Cannot `convert` an object of type -[...] - -julia> vnv_loose = DynamicPPL.loosen_types!!( - vnv, typeof(@varname(y)), Float64, typeof(y_trans) - ); - -julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) - -julia> vnv_loose[@varname(y)] -2×2 Matrix{Float64}: - 1.0 3.0 - 2.0 4.0 -``` -""" -function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} -) where {KNew,VNew,TNew} - K = eltype(vnv.varnames) - V = eltype(vnv.vals) - T = eltype(vnv.transforms) - if KNew <: K && VNew <: V && TNew <: T - return vnv - else - # We could use promote_type here, instead of typejoin. However, that would e.g. 
- # cause Ints to be converted to Float64s, since - # promote_type(Int, Float64) == Float64, which can cause problems. See - # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. - # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing - # and Missing, rather than falling back on Any. However, it's not exported. - vn_type = typejoin(K, KNew) - val_type = typejoin(V, VNew) - transform_type = typejoin(T, TNew) - # This function would work the same way if the first if statement a few lines above - # was skipped, and we only checked for the below condition. However, the first one - # is constant propagated away at compile time (at least on Julia v1.11.7), whereas - # this one isn't. Hence we keep both for performance. - return if vn_type == K && val_type == V && transform_type == T - vnv - elseif isempty(vnv) - VarNamedVector( - Dict{vn_type,Int}(), - Vector{vn_type}(), - UnitRange{Int}[], - Vector{val_type}(), - Vector{transform_type}(), - BitVector(), - Dict{Int,Int}(); - check_consistency=false, - ) - else - # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but - # then here always revert to Vector. - VarNamedVector( - Dict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - Vector{val_type}(vnv.vals), - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end - end -end - -""" - tighten_types!!(vnv::VarNamedVector) - -Return a `VarNamedVector` like `vnv` with the most concrete types possible. - -This function either returns `vnv` itself or new `VarNamedVector` with the same values in -it, but with the element types of various containers made as concrete as possible. - -For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the -transforms are actually identity transformations, this function will return a new -`VarNamedVector` with the transforms vector having eltype `typeof(identity)`. - -This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the -return value may share some of its underlying storage with `vnv`, and thus mutating one may -affect the other. - -# See also -[`loosen_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-tighten-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! - -julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); - -julia> vnv = delete!(vnv, @varname(y)); - -julia> eltype(vnv) -Real - -julia> vnv.transforms -1-element Vector{Any}: - identity (generic function with 1 method) - -julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); - -julia> eltype(vnv_tight) == Int -true - -julia> vnv_tight.transforms -1-element Vector{typeof(identity)}: - identity (generic function with 1 method) -``` -""" -function tighten_types!!(vnv::VarNamedVector) - return if is_tightly_typed(vnv) - # There can not be anything to tighten, so short-circuit. 
- vnv - elseif isempty(vnv) - VarNamedVector() - else - VarNamedVector( - Dict(vnv.varname_to_index...), - [x for x in vnv.varnames], - vnv.ranges, - [x for x in vnv.vals], - [x for x in vnv.transforms], - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end -end - -function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!!(vnv, val, vn) - else - return insert!!(vnv, val, vn) - end -end - -function reset!!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - vnv = setindex_internal!!(vnv, tovec(val), vn, f) - vnv = set_transformed!!(vnv, false, vn) - return vnv -end - -function update!!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. " * - "If you are changing the type or size of a variable you'll need to either " * - "`delete!` it first or use `setindex_internal!`", - ) - end - return setindex_internal!!(vnv, internal_val, vn) -end - -function insert!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!!(vnv, val, vn) -end - -function setindex_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!!(vnv, val, vn, transform) - else - return insert_internal!!(vnv, val, vn, transform) - end -end - -function insert_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) - insert_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function update_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) - update_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function BangBang.push!!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - return setindex!!(vnv, val, vn) -end - -function Base.empty!(vnv::VarNamedVector) - # TODO: Or should the semantics be different, e.g. keeping `varnames`? - empty!(vnv.varname_to_index) - empty!(vnv.varnames) - empty!(vnv.ranges) - empty!(vnv.vals) - empty!(vnv.transforms) - empty!(vnv.is_unconstrained) - empty!(vnv.num_inactive) - return nothing -end -BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv) - -""" - replace_raw_storage(vnv::VarNamedVector, vals::AbstractVector) - -Replace the values in `vnv` with `vals`, as they are stored internally. - -This is useful when we want to update the entire underlying vector of values in one go or if -we want to change the how the values are stored, e.g. alter the `eltype`. - -!!! warning - This replaces the raw underlying values, and so care should be taken when using this - function. For example, if `vnv` has any inactive entries, then the provided `vals` - should also contain the inactive entries to avoid unexpected behavior. 
- -# Examples - -```jldoctest varnamedvector-replace-raw-storage -julia> using DynamicPPL: VarNamedVector, replace_raw_storage - -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> replace_raw_storage(vnv, [2.0])[@varname(x)] == [2.0] -true -``` - -This is also useful when we want to differentiate wrt. the values using automatic -differentiation, e.g. ForwardDiff.jl. - -```jldoctest varnamedvector-replace-raw-storage -julia> using ForwardDiff: ForwardDiff - -julia> f(x) = sum(abs2, replace_raw_storage(vnv, x)[@varname(x)]) -f (generic function with 1 method) - -julia> ForwardDiff.gradient(f, [1.0]) -1-element Vector{Float64}: - 2.0 -``` -""" -replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals - -vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) - -""" - unflatten(vnv::VarNamedVector, vals::AbstractVector) - -Return a new instance of `vnv` with the values of `vals` assigned to the variables. - -This assumes that `vals` have been transformed by the same transformations that that the -values in `vnv` have been transformed by. However, unlike [`replace_raw_storage`](@ref), -`unflatten` does account for inactive entries in `vnv`, so that the user does not have to -care about them. - -This is in a sense the reverse operation of `vnv[:]`. - -The return value may share memory with the input `vnv`, and thus one can not be mutated -safely without affecting the other. - -Unflatten recontiguifies the internal storage, getting rid of any inactive entries. - -# Examples - -```jldoctest varnamedvector-unflatten -julia> using DynamicPPL: VarNamedVector, unflatten - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); - -julia> unflatten(vnv, vnv[:]) == vnv -true -""" -function unflatten(vnv::VarNamedVector, vals::AbstractVector) - if length(vals) != vector_length(vnv) - throw( - ArgumentError( - "Length of `vals` ($(length(vals))) does not match the length of " * - "`vnv` ($(vector_length(vnv))).", - ), - ) - end - new_ranges = vnv.ranges - num_inactive = vnv.num_inactive - if has_inactive(vnv) - new_ranges = recontiguify_ranges!(new_ranges) - num_inactive = Dict{Int,Int}() - end - return VarNamedVector( - vnv.varname_to_index, - vnv.varnames, - new_ranges, - vals, - vnv.transforms, - vnv.is_unconstrained, - num_inactive; - check_consistency=false, - ) -end - -function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) - # Return early if possible. - isempty(left_vnv) && return deepcopy(right_vnv) - isempty(right_vnv) && return deepcopy(left_vnv) - - # Determine varnames. - vns_left = left_vnv.varnames - vns_right = right_vnv.varnames - vns_both = union(vns_left, vns_right) - - # Check that varnames do not subsume each other. - for vn_left in vns_left - for vn_right in vns_right - vn_left == vn_right && continue - # TODO(mhauru) Subsumation doesn't actually need to be a showstopper. For - # instance, if right has a value for `x` and left has a value for `x[1]`, then - # right will take precedence anyway, and we could merge. However, that requires - # some extra logic that hasn't been done yet. - if subsumes(vn_left, vn_right) - throw( - ArgumentError( - "Cannot merge VarNamedVectors: variable name $vn_left " * - "subsumes $vn_right.", - ), - ) - elseif subsumes(vn_right, vn_left) - throw( - ArgumentError( - "Cannot merge VarNamedVectors: variable name $vn_right " * - "subsumes $vn_left.", - ), - ) - end - end - end - - # Determine `eltype` of `vals`. 
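# Illustrative sketch (not part of the patch; hypothetical session) of the
# merge semantics: as with `Base.merge` on dictionaries, right-hand values win
# on shared keys.
#
#   left  = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0])
#   right = VarNamedVector(@varname(y) => [20.0], @varname(z) => [3.0])
#   m = merge(left, right)
#   m[@varname(y)]   # [20.0]
#   keys(m)          # x, y, z -- left's order first, new right keys appended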
- T_left = eltype(left_vnv.vals) - T_right = eltype(right_vnv.vals) - T = typejoin(T_left, T_right) - - # Determine `eltype` of `varnames`. - V_left = eltype(left_vnv.varnames) - V_right = eltype(right_vnv.varnames) - V = typejoin(V_left, V_right) - if !(V <: VarName) - V = VarName - end - - # Determine `eltype` of `transforms`. - F_left = eltype(left_vnv.transforms) - F_right = eltype(right_vnv.transforms) - F = typejoin(F_left, F_right) - - # Allocate. - varname_to_index = Dict{V,Int}() - ranges = UnitRange{Int}[] - vals = T[] - transforms = F[] - is_unconstrained = BitVector(undef, length(vns_both)) - - # Range offset. - offset = 0 - - for (idx, vn) in enumerate(vns_both) - varname_to_index[vn] = idx - # Extract the necessary information from `left` or `right`. - if vn in vns_left && !(vn in vns_right) - # `vn` is only in `left`. - val = getindex_internal(left_vnv, vn) - f = gettransform(left_vnv, vn) - is_unconstrained[idx] = is_transformed(left_vnv, vn) - else - # `vn` is either in both or just `right`. - # Note that in a `merge` the right value has precedence. - val = getindex_internal(right_vnv, vn) - f = gettransform(right_vnv, vn) - is_unconstrained[idx] = is_transformed(right_vnv, vn) - end - n = length(val) - r = (offset + 1):(offset + n) - # Update. - append!(vals, val) - push!(ranges, r) - push!(transforms, f) - # Increment `offset`. - offset += n - end - - return VarNamedVector( - varname_to_index, - vns_both, - ranges, - vals, - transforms, - is_unconstrained; - check_consistency=false, - ) -end - -""" - subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName}) - -Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`. - -Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning -that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. - -Preserves the order of variables in `vnv`. - -# Examples - -```jldoctest varnamedvector-subset -julia> using DynamicPPL: VarNamedVector, @varname, subset - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); - -julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0]) -true - -julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) -true -""" -function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) - vnv_new = similar(vnv) - # Return early if possible. - isempty(vnv) && return vnv_new - - for vn in vnv.varnames - if any(subsumes(vn_given, vn) for vn_given in vns_given) - insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) - set_transformed!(vnv_new, is_transformed(vnv, vn), vn) - end - end - - return tighten_types!!(vnv_new) -end - -""" - similar(vnv::VarNamedVector) - -Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values. - -In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will -be entirely empty, rather than have `undef` values in it. - -# Examples - -```julia-doctest-varnamedvector-similar -julia> using DynamicPPL: VarNamedVector, @varname, similar - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(x[3]) => [3.0]); - -julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}() -true -""" -function Base.similar(vnv::VarNamedVector) - # NOTE: Whether or not we should empty the underlying containers or not - # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will - # result in an empty `AbstractDict`, while the vectors, e.g. 
`vnv.ranges`, - # will result in non-empty vectors but with entries as `undef`. But it's - # much easier to write the rest of the code assuming that `undef` is not - # present, and so for now we empty the underlying containers, thus differing - # from the behavior of `similar` for `AbstractArray`s. - return VarNamedVector( - empty(vnv.varname_to_index), - similar(vnv.varnames, 0), - similar(vnv.ranges, 0), - similar(vnv.vals, 0), - similar(vnv.transforms, 0), - BitVector(), - empty(vnv.num_inactive); - check_consistency=false, - ) -end - -""" - is_contiguous(vnv::VarNamedVector) - -Returns `true` if the underlying data of `vnv` is stored in a contiguous array. - -This is equivalent to negating [`has_inactive(vnv)`](@ref). -""" -is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv) - -""" - nextrange(vnv::VarNamedVector, x) - -Return the range of `length(x)` from the end of current data in `vnv`. -""" -function nextrange(vnv::VarNamedVector, x) - offset = length(vnv.vals) - return (offset + 1):(offset + length(x)) -end - -""" - shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - -Shifts the elements of `x` starting from index `start` by `n` to the right. -""" -function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - x[(start + n):end] = x[start:(end - n)] - return x -end - -""" - shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - -Shifts the ranges of variables in `vnv` starting from index `idx` by `n`. -""" -function shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - for i in (idx + 1):length(vnv.ranges) - vnv.ranges[i] = vnv.ranges[i] .+ n - end - return nothing -end - -# set!! is the function defined in utils.jl that tries to do fancy stuff with optics when -# setting the value of a generic container using a VarName. We can bypass all that because -# VarNamedVector handles VarNames natively. However, it's semantics are slightly different -# from setindex!'s: It allows resetting variables that already have a value with values of -# a different type/size. -set!!(vnv::VarNamedVector, vn::VarName, val) = reset!!(vnv, val, vn) - -function setval!(vnv::VarNamedVector, val, vn::VarName) - return setindex_internal!(vnv, tovec(val), vn) -end - -function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) - offset = 0 - for i in 1:length(ranges) - r_old = ranges[i] - ranges[i] = (offset + 1):(offset + length(r_old)) - offset += length(r_old) - end - - return ranges -end - -""" - contiguify!(vnv::VarNamedVector) - -Re-contiguify the underlying vector and shrink if possible. - -# Examples - -```jldoctest varnamedvector-contiguify -julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]); - -julia> update!(vnv, [23.0, 24.0], @varname(x)); - -julia> has_inactive(vnv) -true - -julia> length(vnv.vals) -4 - -julia> contiguify!(vnv); - -julia> has_inactive(vnv) -false - -julia> length(vnv.vals) -3 - -julia> vnv[@varname(x)] # All the values are still there. -2-element Vector{Float64}: - 23.0 - 24.0 -``` -""" -function contiguify!(vnv::VarNamedVector) - if !has_inactive(vnv) - return vnv - end - # Extract the re-contiguified values. - # NOTE: We need to do this before we update the ranges. - old_vals = copy(vnv.vals) - old_ranges = copy(vnv.ranges) - # And then we re-contiguify the ranges. - recontiguify_ranges!(vnv.ranges) - # Clear the inactive ranges. - empty!(vnv.num_inactive) - # Now we update the values. 
-    for (old_range, new_range) in zip(old_ranges, vnv.ranges)
-        vnv.vals[new_range] = old_vals[old_range]
-    end
-    # And (potentially) shrink the underlying vector.
-    resize!(vnv.vals, vnv.ranges[end][end])
-    # The rest should be left as is.
-    return vnv
-end
-
-"""
-    group_by_symbol(vnv::VarNamedVector)
-
-Return a dictionary mapping symbols to `VarNamedVector`s with varnames containing that
-symbol.
-
-# Examples
-
-```jldoctest varnamedvector-group-by-symbol
-julia> using DynamicPPL: VarNamedVector, @varname, group_by_symbol
-
-julia> vnv = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0], @varname(x[1]) => [3.0]);
-
-julia> d = group_by_symbol(vnv);
-
-julia> collect(keys(d))
-2-element Vector{Symbol}:
- :x
- :y
-
-julia> d[:x] == VarNamedVector(@varname(x) => [1.0], @varname(x[1]) => [3.0])
-true
-
-julia> d[:y] == VarNamedVector(@varname(y) => [2.0])
-true
-```
-"""
-function group_by_symbol(vnv::VarNamedVector)
-    symbols = unique(map(getsym, vnv.varnames))
-    nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols)
-    return OrderedDict(zip(symbols, nt_vals))
-end
-
-"""
-    shift_index_left!(vnv::VarNamedVector, idx::Int)
-
-Shift the index `idx` to the left by one and update the relevant fields.
-
-This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a
-helper function for [`shift_subsequent_indices_left!`](@ref).
-
-!!! warning
-    This does not check whether the index we're shifting to is already occupied.
-"""
-function shift_index_left!(vnv::VarNamedVector, idx::Int)
-    # Shift the index in the lookup table.
-    vn = vnv.varnames[idx]
-    vnv.varname_to_index[vn] = idx - 1
-    # Shift the index in the inactive ranges.
-    if haskey(vnv.num_inactive, idx)
-        # Done in increasing order => don't need to worry about
-        # potentially shifting the same index twice.
-        vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx)
-    end
-end
-
-"""
-    shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int)
-
-Shift the indices for all variables after `idx` to the left by one and update the relevant
-fields.
-
-This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a
-helper function for [`delete!`](@ref).
-"""
-function shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int)
-    # Shift the indices for all variables after `idx`.
-    for idx_to_shift in (idx + 1):length(vnv.varnames)
-        shift_index_left!(vnv, idx_to_shift)
-    end
-end
-
-function Base.delete!(vnv::VarNamedVector, vn::VarName)
-    # Error if we don't have the variable.
-    !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist"))
-
-    # Get the index of the variable.
-    idx = getidx(vnv, vn)
-
-    # Delete the values.
-    r_start = first(getrange(vnv, idx))
-    n_allocated = num_allocated(vnv, idx)
-    # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that.
-    deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1))
-
-    # Delete `vn` from the lookup table.
-    delete!(vnv.varname_to_index, vn)
-
-    # Delete any inactive ranges corresponding to `vn`.
-    haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx)
-
-    # Re-adjust the indices for varnames occurring after `vn` so
-    # that they point to the correct indices after the deletions below.
-    shift_subsequent_indices_left!(vnv, idx)
-
-    # Re-adjust the ranges for varnames occurring after `vn`.
-    shift_subsequent_ranges_by!(vnv, idx, -n_allocated)
-
-    # Delete references from vector fields, thus shifting the indices of
-    # varnames occurring after `vn` by one to the left, as we adjusted for above.
-    deleteat!(vnv.varnames, idx)
-    deleteat!(vnv.ranges, idx)
-    deleteat!(vnv.transforms, idx)
-
-    return vnv
-end
-
-"""
-    delete!!(vnv::VarNamedVector, vn::VarName)
-
-Like `delete!`, but tightens the element types of the returned `VarNamedVector`.
-
-# See also:
-[`tighten_types!!`](@ref)
-"""
-BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn))
-
-"""
-    values_as(vnv::VarNamedVector[, T])
-
-Return the values/realizations in `vnv` as type `T`, if implemented.
-
-If no type `T` is provided, return values as stored in `vnv`.
-
-# Examples
-
-```jldoctest
-julia> using DynamicPPL: VarNamedVector
-
-julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]);
-
-julia> values_as(vnv) == [1.0, 2.0]
-true
-
-julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0])
-true
-
-julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0])
-true
-
-julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0])
-true
-```
-"""
-values_as(vnv::VarNamedVector) = values_as(vnv, Vector)
-values_as(vnv::VarNamedVector, ::Type{Vector}) = getindex_internal(vnv, :)
-function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T}
-    return convert(Vector{T}, values_as(vnv, Vector))
-end
-function values_as(vnv::VarNamedVector, ::Type{NamedTuple})
-    return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv)))
-end
-function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
-    return ConstructionBase.constructorof(D)(pairs(vnv))
-end
-
-# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
-# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.
-
-# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
-# only covers some simple cases. It's probably sufficient in most situations though.
-function hasvalue(vnv::VarNamedVector, vn::VarName)
-    haskey(vnv, vn) && return true
-    any(subsumes(vn, k) for k in keys(vnv)) && return true
-    # Handle the easy case where the right symbol isn't even present.
-    !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false
-
-    optic = getoptic(vn)
-    if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic
-        # If vn is of the form @varname(somesymbol[someindex]), we check whether we store
-        # @varname(somesymbol) and can index into it with someindex. If we rather have a
-        # composed optic with the last part being an index lens, we do a similar check but
-        # stripping out the last index lens part. If these pass, the answer is definitely
-        # "yes". If not, we still don't know for sure.
-        # TODO(mhauru) What about cases where vnv stores both @varname(x) and
-        # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently
-        # aren't.
-        head, tail = if optic isa Accessors.ComposedOptic
-            decomp_optic = Accessors.decompose(optic)
-            first(decomp_optic), Accessors.compose(decomp_optic[2:end]...)
-        else
-            optic, identity
-        end
-        parent_varname = VarName{getsym(vn)}(tail)
-        if haskey(vnv, parent_varname)
-            valvec = getindex(vnv, parent_varname)
-            return canview(head, valvec)
-        end
-    end
-    throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)"))
-end
-
-# TODO(mhauru) Like hasvalue, this is only partially implemented.
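# Illustrative sketch (not part of the patch; hypothetical session) of the
# distinction drawn above: `haskey` matches stored keys exactly, while
# `hasvalue`/`getvalue` can also resolve indexing into a stored container.
#
#   vnv = VarNamedVector(@varname(x) => [1.0, 2.0])
#   haskey(vnv, @varname(x[1]))    # false -- only `x` itself is a key
#   hasvalue(vnv, @varname(x[1]))  # true  -- x[1] can be read out of x
#   getvalue(vnv, @varname(x[1]))  # 1.0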
-function getvalue(vnv::VarNamedVector, vn::VarName)
-    !hasvalue(vnv, vn) && throw(KeyError(vn))
-    haskey(vnv, vn) && return getindex(vnv, vn)
-
-    subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv))
-    if length(subsumed_keys) > 0
-        # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them?
-        return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys)
-    end
-
-    optic = getoptic(vn)
-    # See hasvalue for some comments on the logic of this if block.
-    if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic
-        head, tail = if optic isa Accessors.ComposedOptic
-            decomp_optic = Accessors.decompose(optic)
-            first(decomp_optic), Accessors.compose(decomp_optic[2:end]...)
-        else
-            optic, identity
-        end
-        parent_varname = VarName{getsym(vn)}(tail)
-        valvec = getindex(vnv, parent_varname)
-        return head(valvec)
-    end
-    throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)"))
-end
-
-Base.get(vnv::VarNamedVector, vn::VarName) = getvalue(vnv, vn)
diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl
index 6ce1a861e..b0cafa364 100644
--- a/src/vntvarinfo.jl
+++ b/src/vntvarinfo.jl
@@ -17,6 +17,15 @@ VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size

 VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators())

+function VNTVarInfo(values::Union{NamedTuple,AbstractDict})
+    vi = VarInfo()
+    for (k, v) in pairs(values)
+        vn = k isa Symbol ? VarName{k}() : k
+        vi = setindex!!(vi, v, vn)
+    end
+    return vi
+end
+
 function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
     return VNTVarInfo(Random.default_rng(), model, init_strategy)
 end
diff --git a/test/model.jl b/test/model.jl
index 7c5dc2fcc..281eaaad4 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -25,9 +25,6 @@ function innermost_distribution_type(d::Distributions.Product)
     return dists[1]
 end

-is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false
-is_type_stable_varinfo(varinfo::DynamicPPL.VNTVarInfo) = true
-
 const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

 @testset "model.jl" begin
@@ -221,7 +218,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval))
     end

-    @testset "Dynamic constraints, Metadata" begin
+    @testset "Dynamic constraints" begin
         model = DynamicPPL.TestUtils.demo_dynamic_constraint()
         vi = VarInfo(model)
         vi = link!!(vi, model)
@@ -415,10 +412,7 @@
         end
         vns = DynamicPPL.TestUtils.varnames(model)
         example_values = DynamicPPL.TestUtils.rand_prior_true(model)
-        varinfos = filter(
-            is_type_stable_varinfo,
-            DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
-        )
+        varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
             @test begin
                 @inferred(DynamicPPL.evaluate!!(model, varinfo))
diff --git a/test/runtests.jl b/test/runtests.jl
index 23dda437b..6521f1e4a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -52,9 +52,7 @@ include("test_util.jl")
     include("accumulators.jl")
     include("compiler.jl")
     include("varnamedtuple.jl")
-    # include("varnamedvector.jl")
     include("varinfo.jl")
-    # include("simple_varinfo.jl")
     include("model.jl")
     include("distribution_wrappers.jl")
     include("linking.jl")
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
deleted file mode 100644
index 2c0e21bec..000000000
--- a/test/simple_varinfo.jl
+++ /dev/null
@@ -1,345 +0,0 @@
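# Illustrative usage sketch (not part of the patch; hypothetical session) for
# the `VNTVarInfo(values)` convenience constructor added in the
# src/vntvarinfo.jl hunk above:
#
#   vi = VNTVarInfo((m = 1.0, s = [2.0, 3.0]))           # from a NamedTuple
#   vi[@varname(m)]                                      # 1.0
#   vi = VNTVarInfo(OrderedDict(@varname(s[1]) => 2.0))  # or from a Dict keyed by VarNames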
-@testset "simple_varinfo.jl" begin - @testset "constructor & indexing" begin - @testset "NamedTuple" begin - svi = SimpleVarInfo(; m=1.0) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(; m=[1.0]) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(; m=(a=[1.0],)) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogjoint(svi) isa Float32 - - svi = SimpleVarInfo((m=1.0,)) - svi = accloglikelihood!!(svi, 1.0) - @test getlogjoint(svi) == 1.0 - end - - @testset "Dict" begin - svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0])) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],))) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0])) - # Now we only have a variable `m.a` which is subsumed by `m`, - # but we can't guarantee that we have the "entire" `m`. - @test !haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - end - - @testset "VarNamedVector" begin - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0])) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the - # next test is here to remind of us that. - svi = SimpleVarInfo( - push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0]) - ) - @test_broken !haskey(svi, @varname(m.a.b.c.d)) - end - end - - @testset "link!! & invlink!! 
on $(nameof(model))" for model in - DynamicPPL.TestUtils.ALL_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())), - ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), - ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), - ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), - ) - if name == "SVI{NamedTuple}" && - model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints - # TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where - # a variable set with IndexLenses changes dimension under linking. This - # makes the link!! call crash. The below call to @test just marks the fact - # that there's something broken here. - @test false broken = true - continue - end - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) - end - vi = last(DynamicPPL.evaluate!!(model, vi)) - - # Calculate ground truth - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_constrained... - ) - - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_unlinked = getlogjoint(vi_linked) - lp_linked = getlogjoint_internal(vi_linked) - @test lp_linked ≈ lp_linked_true - @test lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_linked) ≈ lp_unlinked - - # `invlink!!` - vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_unlinked = getlogjoint(vi_invlinked) - also_lp_unlinked = getlogjoint_internal(vi_invlinked) - @test lp_unlinked ≈ lp_unlinked_true - @test also_lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_invlinked) ≈ lp_unlinked - - # Should result in same values. - @test all( - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ) - end - end - - @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.ALL_MODELS - if model.f === DynamicPPL.TestUtils.demo_nested_colons - # TODO(mhauru) Either VarNamedVector or SimpleVarInfo has a bug that causes - # the push!! below to fail with a NamedTuple variable like what - # demo_nested_colons has. I don't want to fix it now though, because this may - # all go soon (as of 2025-12-16). - @test false broken = true - continue - end - # We might need to pre-allocate for the variable `m`, so we need - # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) - svi_dict = SimpleVarInfo(VarInfo(model), Dict) - vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) - vnv = push!!(vnv, VarName{k}() => v) - end - svi_vnv = SimpleVarInfo(vnv) - - @testset "$name" for (name, svi) in ( - ("NamedTuple", svi_nt), - ("Dict", svi_dict), - ("VarNamedVector", svi_vnv), - # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), - ) - # Random seed is set in each `@testset`, so we need to sample - # a new realization for `m` here. - retval = model() - - ### Sampling ### - # Sample a new varinfo! - _, svi_new = DynamicPPL.init!!(model, svi) - - # Realization for `m` should be different wp. 1. 
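-            # ("wp. 1" = with probability 1: init!! draws fresh values here,
-            # so an exact collision with the earlier draw should not occur.)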
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                @test svi_new[vn] != get(retval, vn)
-            end
-
-            # Logjoint should be non-zero wp. 1.
-            @test getlogjoint(svi_new) != 0
-
-            ### Evaluation ###
-            values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
-            if DynamicPPL.is_transformed(svi)
-                _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian(
-                    model, values_eval_constrained...
-                )
-                values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                    model, values_eval_constrained...
-                )
-                # Make sure that these two computation paths provide the same
-                # transformed values.
-                @test values_eval == _values_prior
-            else
-                logpri_true = DynamicPPL.TestUtils.logprior_true(
-                    model, values_eval_constrained...
-                )
-                logπ_true = DynamicPPL.TestUtils.logjoint_true(
-                    model, values_eval_constrained...
-                )
-                values_eval = values_eval_constrained
-            end
-
-            # No logabsdet-jacobian correction needed for the likelihood.
-            loglik_true = DynamicPPL.TestUtils.loglikelihood_true(
-                model, values_eval_constrained...
-            )
-
-            # Update the realizations in `svi_new`.
-            svi_eval = svi_new
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn)
-            end
-
-            # Reset the logp accumulators.
-            svi_eval = DynamicPPL.resetaccs!!(svi_eval)
-
-            # Compute `logjoint` using the varinfo.
-            logπ = logjoint(model, svi_eval)
-            logpri = logprior(model, svi_eval)
-            loglik = loglikelihood(model, svi_eval)
-
-            # Values should not have changed.
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                # TODO(mhauru) Workaround for
-                # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
-                # Remove once the fix is in all Julia versions we support.
-                val = get(values_eval, vn)
-                if val isa Cholesky
-                    @test svi_eval[vn].L == val.L
-                else
-                    @test svi_eval[vn] == val
-                end
-            end
-
-            # Compare log-probability computations.
-            @test logpri ≈ logpri_true
-            @test loglik ≈ loglik_true
-            @test logπ ≈ logπ_true
-        end
-    end
-
-    @testset "Dynamic constraints" begin
-        model = DynamicPPL.TestUtils.demo_dynamic_constraint()
-
-        # Initialize.
-        svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true)
-        svi_nt = last(DynamicPPL.init!!(model, svi_nt))
-        svi_vnv = DynamicPPL.set_transformed!!(
-            SimpleVarInfo(DynamicPPL.VarNamedVector()), true
-        )
-        svi_vnv = last(DynamicPPL.init!!(model, svi_vnv))
-
-        for svi in (svi_nt, svi_vnv)
-            # Sample with large variations in unconstrained space.
-            for i in 1:10
-                for vn in keys(svi)
-                    svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn)
-                end
-                retval, svi = DynamicPPL.evaluate!!(model, svi)
-                @test retval.m == svi[@varname(m)] # `m` is unconstrained
-                @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m`
-
-                retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                    model, retval.m, retval.x
-                )
-
-                # Realizations from model should all be equal to the unconstrained realization.
-                for vn in DynamicPPL.TestUtils.varnames(model)
-                    @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6
-                end
-
-                # `getlogp` should be equal to the logjoint with log-absdet-jac correction.
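-                # That is, we expect lp ≈ logjoint + log|det J|, with J the
-                # Jacobian of the link transform at the constrained values.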
- lp = getlogjoint_internal(svi) - # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 - @test lp ≈ lp_true atol = 1.2e-5 - end - end - end - - @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() - - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Initialize varinfo and link. - vi_linked = DynamicPPL.link!!(vi, model) - - # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.is_transformed( - DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) - ) - - # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) - @test !DynamicPPL.is_transformed(vi_result) - - # Set the values to something that is out of domain if we're in constrained space. - for vn in keys(vi) - vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) - end - - # NOTE: Evaluating a linked VarInfo, **specifically when the transformation - # is static**, will result in an invlinked VarInfo. This is because of - # `maybe_invlink_before_eval!`, which only invlinks if the transformation - # is static. (src/abstract_varinfo.jl) - retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ - DynamicPPL.tovec(retval.s) # `s` is unconstrained in original - @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) - ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result - - # `m` should not be transformed. - @test vi_linked[@varname(m)] == retval.m - @test vi_unlinked_again[@varname(m)] == retval.m - - # Get ground truths - retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.s, retval.m - ) - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ - DynamicPPL.tovec(retval_unconstrained.s) - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ - DynamicPPL.tovec(retval_unconstrained.m) - - # The unlinked varinfo should hold the unlinked logp. - lp_unlinked = getlogjoint(vi_unlinked_again) - @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true - end - end -end diff --git a/test/test_util.jl b/test/test_util.jl index 821b1e0db..9f6939adf 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,29 +16,6 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -# function short_varinfo_name(vi::DynamicPPL.NTVarInfo) -# return if DynamicPPL.has_varnamedvector(vi) -# "TypedVectorVarInfo" -# else -# "TypedVarInfo" -# end -# end -# short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" -# short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" -function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) - return "SimpleVarInfo{<:NamedTuple,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) - return "SimpleVarInfo{<:OrderedDict,<:Ref}" -end -# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) -# return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -# end -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) -# return "SimpleVarInfo{<:VarNamedVector}" -# end function short_varinfo_name(::DynamicPPL.VNTVarInfo) return "VNTVarInfo" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 1d01a0cf8..8ae0535c7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,17 +1,6 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `identity`. - # So we just check that the original keys are present. - for vn in vns - # Should have all the original keys. - @test haskey(varinfo, vn) - end - else - vns_varinfo = keys(varinfo) - # Should be equivalent. - @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) - end + vns_varinfo = keys(varinfo) + @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) end @testset "varinfo.jl" begin @@ -446,13 +435,9 @@ end varinfos = DynamicPPL.TestUtils.setup_varinfos( model, model(), vns; include_threadsafe=true ) - varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. - vns_supported_standard = [ + vns_supported = [ [@varname(s)], [@varname(m)], [@varname(x[1])], @@ -477,25 +462,10 @@ end [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. - vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. 
- vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported varinfo_subset = subset(varinfo, VarName[]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl deleted file mode 100644 index 9a4ef12c3..000000000 --- a/test/varnamedvector.jl +++ /dev/null @@ -1,711 +0,0 @@ -replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens) - -increase_size_for_test(x::Real) = [x] -increase_size_for_test(x::AbstractArray) = repeat(x, 2) - -decrease_size_for_test(x::Real) = x -decrease_size_for_test(x::AbstractVector) = first(x) -decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) - -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.varnames)) - # If the container is concrete, we need to make sure that the varname types match. - # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then - # we need `vn` to also be of this type. - # => If the varname types don't match, we need to relax the container type. - return any(keys(vnv)) do vn_present - typeof(vn_present) !== typeof(val) - end - end - - return false -end -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.vals)) - return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) - end - - return false -end -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return if isconcretetype(eltype(vnv.transforms)) - # If the container is concrete, we need to make sure that the sizes match. - # => If the sizes don't match, we need to relax the container type. - any(keys(vnv)) do vn_present - size(vnv[vn_present]) != size(val) - end - elseif eltype(vnv.transforms) !== Any - # If it's not concrete AND it's not `Any`, then we should just make it `Any`. - true - else - # Otherwise, it's `Any`, so we don't need to relax the container type. - false - end -end -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -""" - relax_container_types(vnv::VarNamedVector, vn::VarName, val) - relax_container_types(vnv::VarNamedVector, vns, val) - -Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. - -This attempts to avoid unnecessary container type relaxations by checking whether -the container types of `vnv` are already compatible with `vn` and `val`. - -# Notes -For example, if `vn` is not compatible with the current keys in `vnv`, then -the underlying types will be changed to `VarName` to accommodate `vn`. - -Similarly: -- If `val` is not compatible with the current values in `vnv`, then - the underlying value type will be changed to `Real`. -- If `val` requires a transformation that is not compatible with the current - transformations type in `vnv`, then the underlying transformation type will - be changed to `Any`. 
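-
-# Example
-
-A sketch of typical usage (the names and values here are illustrative only):
-
-```julia
-vnv = DynamicPPL.VarNamedVector(@varname(x) => 1.0)
-# `@varname(y[1])` has a different VarName type, and `[2.0, 3.0]` a different
-# size, than anything stored so far, so both the varname and the transform
-# containers get relaxed before the push:
-vnv = relax_container_types(vnv, @varname(y[1]), [2.0, 3.0])
-push!(vnv, @varname(y[1]) => [2.0, 3.0])
-```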
-""" -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return relax_container_types(vnv, [vn], [val]) -end -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) - if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index) - varnames_new = convert(Vector{VarName}, vnv.varnames) - else - varname_to_index_new = vnv.varname_to_index - varnames_new = vnv.varnames - end - - transforms_new = if need_transforms_relaxation(vnv, vns, vals) - convert(Vector{Any}, vnv.transforms) - else - vnv.transforms - end - - vals_new = if need_values_relaxation(vnv, vns, vals) - convert(Vector{Real}, vnv.vals) - else - vnv.vals - end - - return DynamicPPL.VarNamedVector( - varname_to_index_new, - varnames_new, - vnv.ranges, - vals_new, - transforms_new, - vnv.is_unconstrained, - vnv.num_inactive, - ) -end - -@testset "VarNamedVector" begin - # Test element-related operations: - # - `getindex` - # - `setindex!` - # - `push!` - # - `update!` - # - `insert!` - # - `reset!` - # - `_internal!` versions of the above - # - !! versions of the above - # - # And these are all be tested for different types of values: - # - scalar - # - vector - # - matrix - - # Test operations on `VarNamedVector`: - # - `empty!` - # - `iterate` - # - `convert` to - # - `AbstractDict` - test_pairs = OrderedDict( - @varname(x[1]) => rand(), - @varname(x[2]) => rand(2), - @varname(x[3]) => rand(2, 3), - @varname(y[1]) => rand(), - @varname(y[2]) => rand(2), - @varname(y[3]) => rand(2, 3), - @varname(z[1]) => rand(1:10), - @varname(z[2]) => rand(1:10, 2), - @varname(z[3]) => rand(1:10, 2, 3), - ) - test_vns = collect(keys(test_pairs)) - test_vals = collect(values(test_pairs)) - - @testset "constructor: no args" begin - # Empty. - vnv = DynamicPPL.VarNamedVector() - @test isempty(vnv) - @test eltype(vnv) == Union{} - - # Empty with types. - vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() - @test isempty(vnv) - @test eltype(vnv) == Float64 - end - - test_varnames_iter = combinations(test_vns, 2) - @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter - val_left = test_pairs[vn_left] - val_right = test_pairs[vn_right] - vnv_base = DynamicPPL.VarNamedVector([vn_left, vn_right], [val_left, val_right]) - - # We'll need the transformations later. - # TODO: Should we test other transformations than just `ReshapeTransform`? - from_vec_left = DynamicPPL.from_vec_transform(val_left) - from_vec_right = DynamicPPL.from_vec_transform(val_right) - to_vec_left = inverse(from_vec_left) - to_vec_right = inverse(from_vec_right) - - # Compare to alternative constructors. - vnv_from_dict = DynamicPPL.VarNamedVector( - OrderedDict(vn_left => val_left, vn_right => val_right) - ) - @test vnv_base == vnv_from_dict - - # We want the types of fields such as `varnames` and `transforms` to specialize - # whenever possible + some functionality, e.g. `push!`, is only sensible - # if the underlying containers can support it. 
- # Expected behavior - should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) - should_have_restricted_transform_type = size(val_left) == size(val_right) - # Actual behavior - has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) - has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) - - @testset "type specialization" begin - @test !should_have_restricted_varname_type || has_restricted_varname_type - @test !should_have_restricted_transform_type || has_restricted_transform_type - end - - @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) - @test DynamicPPL.length_internal(vnv_base) == length(val_left) + length(val_right) - @test length(vnv_base) == 2 - - @test !isempty(vnv_base) - - @testset "empty!" begin - vnv = deepcopy(vnv_base) - empty!(vnv) - @test isempty(vnv) - end - - @testset "similar" begin - vnv = similar(vnv_base) - @test isempty(vnv) - @test typeof(vnv) == typeof(vnv_base) - end - - @testset "getindex" begin - # With `VarName` index. - @test vnv_base[vn_left] == val_left - @test vnv_base[vn_right] == val_right - end - - @testset "getindex_internal" begin - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_left) == - to_vec_left(val_left) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_right) == - to_vec_right(val_right) - end - - @testset "getindex_internal with Ints" begin - for (i, val) in enumerate(to_vec_left(val_left)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, i) == val - end - offset = length(to_vec_left(val_left)) - for (i, val) in enumerate(to_vec_right(val_right)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, offset + i) == val - end - end - - @testset "update!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "delete!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - @test !haskey(vnv, vn_left) - @test haskey(vnv, vn_right) - delete!(vnv, vn_right) - @test !haskey(vnv, vn_right) - end - - @testset "insert!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert!!" 
begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "merge" begin - # When there are no inactive entries, `merge` on itself result in the same. - @test merge(vnv_base, vnv_base) == vnv_base - - # Merging with empty should result in the same. - @test merge(vnv_base, similar(vnv_base)) == vnv_base - @test merge(similar(vnv_base), vnv_base) == vnv_base - - # With differences. - vnv_left_only = deepcopy(vnv_base) - delete!(vnv_left_only, vn_right) - vnv_right_only = deepcopy(vnv_base) - delete!(vnv_right_only, vn_left) - - # `(x,)` and `(x, y)` should be `(x, y)`. - @test merge(vnv_left_only, vnv_base) == vnv_base - # `(x, y)` and `(x,)` should be `(x, y)`. - @test merge(vnv_base, vnv_left_only) == vnv_base - # `(x, y)` and `(y,)` should be `(x, y)`. - @test merge(vnv_base, vnv_right_only) == vnv_base - # `(y,)` and `(x, y)` should be `(y, x)`. - vnv_merged = merge(vnv_right_only, vnv_base) - @test vnv_merged != vnv_base - @test collect(keys(vnv_merged)) == [vn_right, vn_left] - end - - @testset "push!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - vnv_copy = deepcopy(vnv) - push!(vnv, (vn => val)) - @test vnv[vn] == val - end - end - - @testset "setindex_internal!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!( - vnv, typeof(vn_left), eltype(vnv), typeof(increment) - ) - DynamicPPL.setindex_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.loosen_types!!( - vnv, typeof(vn_right), eltype(vnv), typeof(increment) - ) - DynamicPPL.setindex_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. 
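-        # (These VarNames and values may have types and sizes not seen before,
-        # hence the container types are relaxed first.)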
- vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - DynamicPPL.setindex_internal!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex_internal! with Ints" begin - vnv = deepcopy(vnv_base) - for i in 1:DynamicPPL.length_internal(vnv_base) - DynamicPPL.setindex_internal!(vnv, i, i) - end - for i in 1:DynamicPPL.length_internal(vnv_base) - @test DynamicPPL.getindex_internal(vnv, i) == i - end - end - - @testset "setindex_internal!!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - # Note that unlike with setindex_internal!, we don't need loosen_types!! here. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex! and reset!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - expected_length = if haskey(vnv, vn) - # If it's already present, the resulting length will be unchanged. - DynamicPPL.length_internal(vnv) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - vnv[vn] = val .+ 1 - x = DynamicPPL.getindex_internal(vnv, :) - @test vnv[vn] == val .+ 1 - @test DynamicPPL.length_internal(vnv) == expected_length - @test length(x) == DynamicPPL.length_internal(vnv) - @test all( - DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) - ) - - # There should be no redundant values in the underlying vector. - @test !DynamicPPL.has_inactive(vnv) - end - - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn (increased size)" for vn in test_vns - val_original = test_pairs[vn] - val = increase_size_for_test(val_original) - vn_already_present = haskey(vnv, vn) - expected_length = if vn_already_present - # If it's already present, the resulting length will be altered. - DynamicPPL.length_internal(vnv) + length(val) - length(val_original) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - # Have to use reset!, because setindex! doesn't support decreasing size. 
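-                # (reset! reassigns the storage range for `vn` to fit the new
-                # value, which is what allows the size to change here.)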
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn (decreased size)" for vn in test_vns
-                val_original = test_pairs[vn]
-                val = decrease_size_for_test(val_original)
-                vn_already_present = haskey(vnv, vn)
-                expected_length = if vn_already_present
-                    # If it's already present, the resulting length will be altered.
-                    DynamicPPL.length_internal(vnv) + length(val) - length(val_original)
-                else
-                    DynamicPPL.length_internal(vnv) + length(val)
-                end
-
-                # Have to use reset!, because setindex! doesn't support decreasing size.
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-        end
-    end
-
-    @testset "growing and shrinking" begin
-        @testset "deterministic" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-            # Growing should not create inactive ranges.
-            for i in 1:n
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test !DynamicPPL.has_inactive(vnv)
-            end
-
-            # Same size should not create inactive ranges.
-            x = fill(true, n)
-            DynamicPPL.update_internal!(vnv, x, vn, identity)
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Shrinking should create inactive ranges.
-            for i in (n - 1):-1:1
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test DynamicPPL.has_inactive(vnv)
-                @test DynamicPPL.num_inactive(vnv, vn) == n - i
-            end
-        end
-
-        @testset "random" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Insert a bunch of random-length vectors.
-            for i in 1:100
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-            end
-            # Should never be allocating more than `n` elements.
-            @test DynamicPPL.num_allocated(vnv, vn) ≤ n
-
-            # If we contiguify, then it should always be the same size as just inserted.
-            for i in 1:10
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-                DynamicPPL.contiguify!(vnv)
-                @test DynamicPPL.num_allocated(vnv, vn) == length(x)
-            end
-        end
-    end
-
-    @testset "subset" begin
-        vnv = DynamicPPL.VarNamedVector(test_pairs)
-        @test subset(vnv, test_vns) == vnv
-        @test subset(vnv, VarName[]) == DynamicPPL.VarNamedVector()
-        @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv
-
-        # Test that subset preserves transformations and unconstrainedness.
-        vn = @varname(t[1])
-        vns = vcat(test_vns, [vn])
-        vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2)
-        DynamicPPL.set_transformed!(vnv, true, @varname(t[1]))
-        @test vnv[@varname(t[1])] == [4.0]
-        @test is_transformed(vnv, @varname(t[1]))
-        @test subset(vnv, vns) == vnv
-    end
-
-    @testset "loosen and tighten types" begin
-        """
-            test_tightenability(vnv::VarNamedVector)
-
-        Test that tighten_types!! is a no-op on `vnv`.
- """ - function test_tightenability(vnv::DynamicPPL.VarNamedVector) - @test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv)) - # TODO(mhauru) We would like to check something more stringent here, namely that - # the operation is compiled to a direct no-op, with no instructions at all. I - # don't know how to do that though, so for now we just check that it doesn't - # allocate. - @allocations(DynamicPPL.tighten_types!!(vnv)) == 0 - return nothing - end - - vn = @varname(a[1]) - # Test that tighten_types!! is a no-op on an empty VarNamedVector. - vnv = DynamicPPL.VarNamedVector() - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - # Also check that it literally returns the same object, and both tighten and loosen - # are type stable. - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise for a VarNamedVector with something pushed into it. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise for a VarNamedVector with abstract element-types, when that is needed for - # the current contents because mixed types have been pushed into it. However, this - # time, since the types are only as tight as they can be, but not actually concrete, - # tighten_types!! can't be type stable. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - vnv = setindex!!(vnv, 2, @varname(b)) - @test !DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise when first mixed types are pushed, but then deleted. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - vnv = setindex!!(vnv, 2, @varname(b)) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = delete!!(vnv, vn) - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - - # Test that loosen_types!! does really loosen them and that tighten_types!! reverts - # that. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - @test DynamicPPL.is_tightly_typed(vnv) - k = eltype(vnv.varnames) - e = eltype(vnv.vals) - t = eltype(vnv.transforms) - # Loosen key type. - vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosen element type - vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosen transformation type - vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosening to the same types as currently should do nothing. 
- vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t) - @test DynamicPPL.is_tightly_typed(vnv) - @allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0 - end -end - -@testset "VarInfo + VarNamedVector" begin - models = DynamicPPL.TestUtils.ALL_MODELS - @testset "$(model.f)" for model in models - # NOTE: Need to set random seed explicitly to avoid using the same seed - # for initialization as for sampling in the inner testset below. - Random.seed!(42) - value_true = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=false - ) - # Filter out those which are not based on `VarNamedVector`. - varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) - # Get the true log joint. - logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Need to make sure we're using a different random seed from the - # one used in the above call to `rand_prior_true`. - Random.seed!(43) - - # Are values correct? - DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) - - # Is evaluation correct? - varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) - # Log density should be the same. - @test getlogjoint(varinfo_eval) ≈ logp_true - # Values should be the same. - DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) - - # Is sampling correct? - varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) - # Log density should be different. - @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) - # Values should be different. - DynamicPPL.TestUtils.test_values( - varinfo_sample, value_true, vns; compare=!isequal - ) - end - end -end From 8ba36f6dffd09c69aaa395bf27245d18ef283de9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:10:10 +0000 Subject: [PATCH 124/148] Fix a lot of doctests --- src/abstract_varinfo.jl | 44 ++++++++++++++++++++------------------- src/model.jl | 8 +++---- src/values_as_in_model.jl | 23 ++++++++------------ src/vntvarinfo.jl | 4 ++++ 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 1c5159626..c4af10898 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -506,13 +506,15 @@ If no `Type` is provided, return values as stored in `varinfo`. julia> # Just use an example model to construct the `VarInfo` because we're lazy. 
vi = DynamicPPL.VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; +julia> vi = DynamicPPL.setindex!!(vi, 1.0, @varname(s)); + +julia> vi = DynamicPPL.setindex!!(vi, 2.0, @varname(m)); julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: +OrderedDict{Any, Any} with 2 entries: s => 1.0 m => 2.0 @@ -570,20 +572,20 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> varinfo = VarInfo(model); +julia> vi = VarInfo(model); -julia> keys(varinfo) +julia> keys(vi) 4-element Vector{VarName}: s m x[1] x[2] -julia> for (i, vn) in enumerate(keys(varinfo)) - varinfo[vn] = i +julia> for (i, vn) in enumerate(keys(vi)) + vi = DynamicPPL.setindex!!(vi, Float64(i), vn) end -julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 @@ -591,59 +593,59 @@ julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4.0 julia> # Extract one with only `m`. - varinfo_subset1 = subset(varinfo, [@varname(m),]); + vi_subset1 = subset(vi, [@varname(m),]); -julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, typeof(identity)}}: +julia> keys(vi_subset1) +1-element Vector{VarName}: m -julia> varinfo_subset1[@varname(m)] +julia> vi_subset1[@varname(m)] 2.0 julia> # Extract one with both `s` and `x[2]`. - varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + vi_subset2 = subset(vi, [@varname(s), @varname(x[2])]); -julia> keys(varinfo_subset2) +julia> keys(vi_subset2) 2-element Vector{VarName}: s x[2] -julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +julia> vi_subset2[[@varname(s), @varname(x[2])]] 2-element Vector{Float64}: 1.0 4.0 ``` -`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref) +`subset` is particularly useful when combined with [`merge(vi::AbstractVarInfo)`](@ref) ```jldoctest varinfo-subset julia> # Merge the two. - varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + vi_subset_merged = merge(vi_subset1, vi_subset2); -julia> keys(varinfo_subset_merged) +julia> keys(vi_subset_merged) 3-element Vector{VarName}: m s x[2] -julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +julia> vi_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] 3-element Vector{Float64}: 1.0 2.0 4.0 julia> # Merge the two with the original. - varinfo_merged = merge(varinfo, varinfo_subset_merged); + vi_merged = merge(vi, vi_subset_merged); -julia> keys(varinfo_merged) +julia> keys(vi_merged) 4-element Vector{VarName}: s m x[1] x[2] -julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 diff --git a/src/model.jl b/src/model.jl index cd36ee44b..7d65df842 100644 --- a/src/model.jl +++ b/src/model.jl @@ -501,7 +501,7 @@ true julia> # Since we conditioned on `a.m`, it is not treated as a random variable. # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: @@ -513,7 +513,7 @@ Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: julia> # Now `a.x` will be sampled. 
keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x ``` """ @@ -839,7 +839,7 @@ julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) true julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: @@ -851,7 +851,7 @@ Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: julia> # Now `a.x` will be sampled. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x ``` """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index f7440d6ff..304b99a3e 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -107,35 +107,30 @@ julia> @model function model_changing_support() julia> model = model_changing_support(); -julia> # Construct initial type-stable `VarInfo`. +julia> # Construct initial `VarInfo`. varinfo = VarInfo(rng, model); julia> # Link it so it works in unconstrained space. - varinfo_linked = DynamicPPL.link(varinfo, model); + varinfo_linked = DynamicPPL.link!!(copy(varinfo), model); -julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. +julia> # Perform computations in unconstrained space, e.g. changing the values of `vals`. # Flip `x` so we hit the other support of `y`. - θ = [!varinfo[@varname(x)], rand(rng)]; + vals = [!varinfo[@varname(x)], rand(rng)]; julia> # Update the `VarInfo` with the new values. - varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, θ); + varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, vals); julia> # Determine the expected support of `y`. - lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) + lb, ub = vals[1] == 1 ? (0, 1) : (11, 12) (0, 1) julia> # Approach 1: Convert back to constrained space using `invlink` and extract. - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); + varinfo_invlinked = DynamicPPL.invlink!!(copy(varinfo_linked), model); -julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions - # used in the very first model evaluation, hence the support of `y` - # is not updated even though `x` has changed. - lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub -false +julia> lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub +true julia> # Approach 2: Extract realizations using `values_as_in_model`. - # (✓) `values_as_in_model` will re-run the model and extract - # the correct realization of `y` given the new values of `x`. 
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub true ``` diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index b0cafa364..a7eafc460 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -54,6 +54,10 @@ function Base.getindex(vi::VNTVarInfo, vn::VarName) return tv.transform(tv.val) end +function Base.getindex(vi::VNTVarInfo, vns::Vector{<:VarName}) + return [getindex(vi, vn) for vn in vns] +end + function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) val = getindex_internal(vi, vn) return from_maybe_linked_internal(vi, vn, dist, val) From 1f6335db7c097ff393696723bf4c64179486e68b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:11:03 +0000 Subject: [PATCH 125/148] Rename vntvarinfo.jl to varinfo.jl --- src/DynamicPPL.jl | 2 +- src/{vntvarinfo.jl => varinfo.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{vntvarinfo.jl => varinfo.jl} (100%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b5a77be03..d6f4025ca 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -199,7 +199,7 @@ include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -include("vntvarinfo.jl") +include("varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/vntvarinfo.jl b/src/varinfo.jl similarity index 100% rename from src/vntvarinfo.jl rename to src/varinfo.jl From dbcf5f646b70a965fd739f3a0accca05e483c564 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:14:30 +0000 Subject: [PATCH 126/148] Rename VNTVarInfo to VarInfo --- src/logdensityfunction.jl | 2 +- src/varinfo.jl | 135 ++++++++++++++++++-------------------- test/test_util.jl | 4 +- 3 files changed, 68 insertions(+), 73 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 4f8ac4933..17101d0d2 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -303,7 +303,7 @@ representation, along with whether each variable is linked or unlinked. This function returns a VarNamedTuple mapping all VarNames to their corresponding `RangeAndLinked`. """ -function get_ranges_and_linked(vi::VNTVarInfo) +function get_ranges_and_linked(vi::VarInfo) # TODO(mhauru) Check that the closure doesn't cause type instability here. vnt = VarNamedTuple() vnt, _ = mapreduce( diff --git a/src/varinfo.jl b/src/varinfo.jl index a7eafc460..37728dca2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,11 +1,8 @@ -struct VNTVarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo +struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs end -# TODO(mhauru) Make this renaming permanent. -const VarInfo = VNTVarInfo - struct TransformedValue{ValType,TransformType,SizeType} val::ValType linked::Bool @@ -15,9 +12,9 @@ end VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size -VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators()) +VarInfo() = VarInfo(VarNamedTuple(), default_accumulators()) -function VNTVarInfo(values::Union{NamedTuple,AbstractDict}) +function VarInfo(values::Union{NamedTuple,AbstractDict}) vi = VarInfo() for (k, v) in pairs(values) vn = k isa Symbol ? 
VarName{k}() : k @@ -26,52 +23,52 @@ function VNTVarInfo(values::Union{NamedTuple,AbstractDict}) return vi end -function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VNTVarInfo(Random.default_rng(), model, init_strategy) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end -function VNTVarInfo( +function VarInfo( rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return last(init!!(rng, model, VNTVarInfo(), init_strategy)) + return last(init!!(rng, model, VarInfo(), init_strategy)) end -getaccs(vi::VNTVarInfo) = vi.accs -setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs) +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) -transformation(::VNTVarInfo) = DynamicTransformation() +transformation(::VarInfo) = DynamicTransformation() -Base.copy(vi::VNTVarInfo) = VNTVarInfo(copy(vi.values), copy(getaccs(vi))) +Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) -Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn) +Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) -Base.length(vi::VNTVarInfo) = length(vi.values) +Base.length(vi::VarInfo) = length(vi.values) -function Base.getindex(vi::VNTVarInfo, vn::VarName) +function Base.getindex(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.transform(tv.val) end -function Base.getindex(vi::VNTVarInfo, vns::Vector{<:VarName}) +function Base.getindex(vi::VarInfo, vns::Vector{<:VarName}) return [getindex(vi, vn) for vn in vns] end -function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) +function Base.getindex(vi::VarInfo, vn::VarName, dist::Distribution) val = getindex_internal(vi, vn) return from_maybe_linked_internal(vi, vn, dist, val) end -Base.isempty(vi::VNTVarInfo) = isempty(vi.values) -Base.empty(vi::VNTVarInfo) = VNTVarInfo(empty(vi.values), map(reset, vi.accs)) -BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi.accs)) +Base.isempty(vi::VarInfo) = isempty(vi.values) +Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) -function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName) +function setindex_internal!!(vi::VarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, @@ -80,7 +77,7 @@ end # of doing the transformation to the caller, it'll be done even when e.g. using # OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be removed # once VAIMAcc is the only way to get values out of an evaluation. -function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarName) +function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) # Determine whether to insert a transformed value into `vi`. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. 
If not, we @@ -98,81 +95,81 @@ function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarNa transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size) - vi = VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) + vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end -function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) +function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) transform = from_vec_transform(val) transformed_val = inverse(transform)(val) tv = TransformedValue(transformed_val, false, transform, size(val)) - return VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) + return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) end -Base.keys(vi::VNTVarInfo) = keys(vi.values) -Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) +Base.keys(vi::VarInfo) = keys(vi.values) +Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) -function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) +function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -# VNTVarInfo does not care whether the transformation was Static or Dynamic, it just tracks +# VarInfo does not care whether the transformation was Static or Dynamic, it just tracks # whether one was applied at all. -function set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation, vn::VarName) +function set_transformed!!(vi::VarInfo, ::AbstractTransformation, vn::VarName) return set_transformed!!(vi, true, vn) end -set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) +set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) -function set_transformed!!(vi::VNTVarInfo, ::NoTransformation, vn::VarName) +function set_transformed!!(vi::VarInfo, ::NoTransformation, vn::VarName) return set_transformed!!(vi, false, vn) end -set_transformed!!(vi::VNTVarInfo, ::NoTransformation) = set_transformed!!(vi, false) +set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) -function set_transformed!!(vi::VNTVarInfo, linked::Bool) +function set_transformed!!(vi::VarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv TransformedValue(tv.val, linked, tv.transform, tv.size) end - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -function getindex_internal(vi::VNTVarInfo, vn::VarName) +function getindex_internal(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.val end # TODO(mhauru) This is mimicing old behaviour, but is now wrong: The internal # representation does not have to be a Vector. -getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector) +getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -function is_transformed(vi::VNTVarInfo, vn::VarName) +function is_transformed(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.linked end # TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the # below version? 
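# For reference, a sketch of the difference (with a hypothetical stored `x`):
# the active definition below always rebuilds the transform from the
# distribution via from_vec_transform(dist), whereas the commented-out
# alternative would instead reuse the transform that was recorded in
# `vi.values` at the time the value for `@varname(x)` was inserted.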
-function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) end -# function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) +# function from_internal_transform(vi::VarInfo, vn::VarName, ::Distribution) # return getindex(vi.values, vn).transform # end -function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end -function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) +function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -192,18 +189,18 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac += logjac1 + logjac2 return new_tv end - vi = VNTVarInfo(new_values, vi.accs) + vi = VarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end return vi end -function link!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) +function link!!(t::DynamicTransformation, vi::VarInfo, model::Model) return link!!(t, vi, nothing, model) end -function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -224,18 +221,18 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac += logjac1 + logjac2 return new_tv end - vi = VNTVarInfo(new_values, vi.accs) + vi = VarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end return vi end -function invlink!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) +function invlink!!(t::DynamicTransformation, vi::VarInfo, model::Model) return invlink!!(t, vi, nothing, model) end -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) # By default this will simply evaluate the model with `DynamicTransformationContext`, # and so we need to specialize to avoid this. return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) @@ -243,7 +240,7 @@ end function link!!( t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vi::ThreadSafeVarInfo{<:VarInfo}, vns::VarNameTuple, model::Model, ) @@ -252,9 +249,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function invlink!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model -) +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) # By default this will simply evaluate the model with `DynamicTransformationContext`, # and so we need to specialize to avoid this. 
return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) @@ -262,7 +257,7 @@ end function invlink!!( ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vi::ThreadSafeVarInfo{<:VarInfo}, vns::VarNameTuple, model::Model, ) @@ -273,11 +268,11 @@ end # TODO(mhauru) I don't think this should return the internal values, but that's the current # convention. -function values_as(vi::VNTVarInfo, ::Type{Vector}) +function values_as(vi::VarInfo, ::Type{Vector}) return mapfoldl(pair -> tovec(pair.second.val), vcat, vi.values; init=Union{}[]) end -function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} +function values_as(vi::VarInfo, ::Type{T}) where {T<:AbstractDict} return mapfoldl(identity, function (cumulant, pair) vn, tv = pair val = tv.transform(tv.val) @@ -289,7 +284,7 @@ end # interface provided by rand(::Model). We should change that to return a VarNamedTuple # instead, and then this method (and any other values_as methods for NamedTuple) could be # removed. -function values_as(vi::VNTVarInfo, ::Type{NamedTuple}) +function values_as(vi::VarInfo, ::Type{NamedTuple}) return mapfoldl( identity, function (cumulant, pair) @@ -309,7 +304,7 @@ function untyped_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return VNTVarInfo(rng, model, init_strategy) + return VarInfo(rng, model, init_strategy) end function typed_varinfo( @@ -317,10 +312,10 @@ function typed_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return VNTVarInfo(rng, model, init_strategy) + return VarInfo(rng, model, init_strategy) end -typed_varinfo(vi::VNTVarInfo) = vi +typed_varinfo(vi::VarInfo) = vi function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) return typed_varinfo(Random.default_rng(), model, init_strategy) @@ -350,7 +345,7 @@ function get_next_chunk!(vci::VectorChunkIterator, len::Int) return chunk end -function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) +function unflatten!!(vi::VarInfo, vec::AbstractVector) # You may wonder, why have a whole struct for this, rather than just an index variable # that the mapping function would close over. I wonder too. But for some reason type # inference fails on such an index variable, turning it into a Core.Box. @@ -367,16 +362,16 @@ function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) new_val = get_next_chunk!(vci, len) return TransformedValue(new_val, tv.linked, tv.transform, tv.size) end - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -function subset(varinfo::VNTVarInfo, vns) +function subset(varinfo::VarInfo, vns) new_values = subset(varinfo.values, vns) - return VNTVarInfo(new_values, map(copy, getaccs(varinfo))) + return VarInfo(new_values, map(copy, getaccs(varinfo))) end -function Base.merge(varinfo_left::VNTVarInfo, varinfo_right::VNTVarInfo) +function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) - return VNTVarInfo(new_values, new_accs) + return VarInfo(new_values, new_accs) end diff --git a/test/test_util.jl b/test/test_util.jl index 9f6939adf..8f402ad8f 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,8 +16,8 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -function short_varinfo_name(::DynamicPPL.VNTVarInfo) - return "VNTVarInfo" +function short_varinfo_name(::DynamicPPL.VarInfo) + return "VarInfo" end # convenient functions for testing model.jl From 0edaa53e9acd366e75acd595511ecad147734684 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:43:03 +0000 Subject: [PATCH 127/148] Remove (un)typed_varinfo --- src/chains.jl | 8 -------- src/test_utils/contexts.jl | 27 +++++++++++---------------- src/varinfo.jl | 28 ---------------------------- 3 files changed, 11 insertions(+), 52 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index ca653fff9..ee4312547 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -38,7 +38,6 @@ function ParamsWithStats( include_colon_eq::Bool=true, include_log_probs::Bool=true, ) - varinfo = maybe_to_typed_varinfo(varinfo) accs = if include_log_probs ( DynamicPPL.LogPriorAccumulator(), @@ -64,13 +63,6 @@ function ParamsWithStats( return ParamsWithStats(params, stats) end -# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to -# convert it to a typed varinfo first, hence this method. -# https://github.com/TuringLang/Turing.jl/issues/2604 -# maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) -# maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) -maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi - """ ParamsWithStats( varinfo::AbstractVarInfo, diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index c48d2ddfd..cceedee8c 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -36,16 +36,12 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP # varinfos.) Thus we only test evaluation with VarInfos that are already # filled with values. 
@testset "evaluation" begin - # Generate a new filled untyped varinfo - _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) - typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + # Generate a new filled varinfo + _, vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) # Set the test context as the new leaf context new_model = DynamicPPL.setleafcontext(model, context) - # Check that evaluation works - for vi in [untyped_vi, typed_vi] - _, vi = DynamicPPL.evaluate!!(new_model, vi) - @test vi isa DynamicPPL.VarInfo - end + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo end end @@ -73,13 +69,12 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic @testset "initialisation and evaluation" begin new_model = contextualize(model, context) - for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] - # Initialisation - _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) - @test vi isa DynamicPPL.VarInfo - # Evaluation - _, vi = DynamicPPL.evaluate!!(new_model, vi) - @test vi isa DynamicPPL.VarInfo - end + vi = DynamicPPL.VarInfo() + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 37728dca2..170181b80 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -297,34 +297,6 @@ function values_as(vi::VarInfo, ::Type{NamedTuple}) ) end -# TODO(mhauru) These two are now redundant, just conforming to the old interface -# temporarily. -function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return VarInfo(rng, model, init_strategy) -end - -function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return VarInfo(rng, model, init_strategy) -end - -typed_varinfo(vi::VarInfo) = vi - -function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return typed_varinfo(Random.default_rng(), model, init_strategy) -end - -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) -end - """ VectorChunkIterator{T<:AbstractVector} From c2748a79a33d2234a081000479d2cfbf8369f89b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:47:21 +0000 Subject: [PATCH 128/148] Add docstrings to varinfo.jl --- src/varinfo.jl | 135 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 119 insertions(+), 16 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 170181b80..688c90b03 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,12 +1,67 @@ +""" + VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + +The default implementation of `AbstractVarInfo`, storing variable values and accumulators. + +`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, +and a tuple of accumulators. The only really noteworthy thing about it is that it stores +the values of variables vectorised as instances of `TransformedValue`. That is, it stores +each value as a vector and a transformation to be applied to that vector to get the actual +value. 
It also stores whether the transformation is such that it guarantees all real vectors +to be valid internal representations of the variable (i.e., whether the variable has been +linked), as well as the size of the actual post-transformation value. These are all fields +of [`TransformedValue`](@ref). + +Note that `setindex!!` and `getindex` on `VarInfo` deal with the actual values of variables. +To get access to the internal vectorised values, use [`getindex_internal`](@ref), +[`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). + +There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which sets a +variable's value with a transformation based on the statistical distribution this value is +a sample for. + +For more details on the internal storage, see documentation of [`TransformedValue`](@ref) and +[`VarNamedTuple`](@ref). + +# Fields +$(TYPEDFIELDS) + +""" struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs end +# TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was +# using a Vector as the internal storage in all cases. We should revisit this, and allow +# values to be stored "raw", since VarNamedTuple supports it. + +# TODO(mhauru) Related to the above, I think we should reconsider whether we should store +# transformations at all. We rarely use them, since they may be dynamic in a model. +# tilde_assume!! rather gets the transformation from the current distribution encountered +# during model execution. However, this would change the interface quite a lot, so I want to +# finish implementing VarInfo using VNT (mostly) respecting the old interface first. + +""" + TransformedValue{ValType,TransformType,SizeType} + +A struct for storing a variable's value in its internal (vectorised) form. + +# Fields +$(TYPEDFIELDS) +""" struct TransformedValue{ValType,TransformType,SizeType} + "The internal (vectorised) value." val::ValType + """Boolean indicating whether the variable is linked, i.e. the transformation maps all + real vectors to valid values.""" linked::Bool + """The transformation from internal (vectorised) to actual value. In other words, the + actual value of the variable being stored is `transform(val)`.""" transform::TransformType + """The size of the actual value after transformation. This is needed when a + TransformedValue is stored as a block in an array (see [`PartialArray`](@ref) in + `VarNamedTuples`).""" size::SizeType end @@ -41,10 +96,10 @@ setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) transformation(::VarInfo) = DynamicTransformation() Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) - Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) - Base.length(vi::VarInfo) = length(vi.values) +Base.keys(vi::VarInfo) = keys(vi.values) +Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) function Base.getindex(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) @@ -64,6 +119,13 @@ Base.isempty(vi::VarInfo) = isempty(vi.values) Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) +""" + setindex_internal!!(vi::VarInfo, val, vn::VarName) + +Set the internal (vectorised) value of variable `vn` in `vi` to `val`. + +This does not change the transformation or linked status of the variable. 
+""" function setindex_internal!!(vi::VarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) @@ -73,10 +135,23 @@ end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, # we need `dist` to determine the linking transformation (or even just the vectorisation -# transformation, in the case of ProductNamedTupleDistribions), and if we leave the work -# of doing the transformation to the caller, it'll be done even when e.g. using -# OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be removed -# once VAIMAcc is the only way to get values out of an evaluation. +# transformation in the case of ProductNamedTupleDistribions), and if we leave the work +# of doing the transformation to the caller (tilde_assume!!), it'll be done even when e.g. +# using OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be +# removed once VAIMAcc is the only way to get values out of an evaluation. +""" + setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) + +Set the value of `vn` in `vi` to `val`, applying a transformation based on `dist`. + +`val` is taken to be the actual value of the variable, and is transformed into the internal +(vectorised) representation using a transformation based on `dist`. If the variable is +linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are linked, the +linking transformation is used; otherwise, the standard vector transformation is used. + +Returns the modified `vi` together with the log absolute determinant of the Jacobian of the +transformation applied. +""" function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) # Determine whether to insert a transformed value into `vi`. # If the VarInfo alrady had a value for this variable, we will @@ -99,6 +174,14 @@ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) return vi, logjac end +""" + setindex!!(vi::VarInfo, val, vn::VarName) + +Set the value of `vn` in `vi` to `val`. + +The transformation for `vn` is reset to be the standard vector transformation for values of +the type of `val` and linking status is set to false. +""" function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) transform = from_vec_transform(val) transformed_val = inverse(transform)(val) @@ -106,9 +189,13 @@ function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) end -Base.keys(vi::VarInfo) = keys(vi.values) -Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) +""" + set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) + +Set the linked status of variable `vn` in `vi` to `linked`. +This does not change the value or transformation of the variable. +""" function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) @@ -137,13 +224,16 @@ function set_transformed!!(vi::VarInfo, linked::Bool) return VarInfo(new_values, vi.accs) end +""" + getindex_internal(vi::VarInfo, vn::VarName) + +Get the internal (vectorised) value of variable `vn` in `vi`. +""" function getindex_internal(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.val end -# TODO(mhauru) This is mimicing old behaviour, but is now wrong: The internal -# representation does not have to be a Vector. 
getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) function is_transformed(vi::VarInfo, vn::VarName) @@ -151,20 +241,18 @@ function is_transformed(vi::VarInfo, vn::VarName) return tv.linked end -# TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the -# below version? function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) end -# function from_internal_transform(vi::VarInfo, vn::VarName, ::Distribution) -# return getindex(vi.values, vn).transform -# end - function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end +function from_internal_transform(vi::VarInfo, vn::VarName) + return getindex(vi.values, vn).transform +end + function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end @@ -337,11 +425,26 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) return VarInfo(new_values, vi.accs) end +""" + subset(varinfo::VarInfo, vns) + +Create a new `VarInfo` containing only the variables in `vns`. + +`vns` can be almost any collection of `VarName`s, e.g. a `Set`, `Vector`, or `Tuple`. +""" function subset(varinfo::VarInfo, vns) new_values = subset(varinfo.values, vns) return VarInfo(new_values, map(copy, getaccs(varinfo))) end +""" + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + +Merge two `VarInfo`s into a new `VarInfo` containing all variables from both. + +If a variable exists in both `varinfo_left` and `varinfo_right`, the value from +`varinfo_right` is used. +""" function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) From 6dbae236ddf5b3345c959c20498d1a802bd27e0f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 17:15:31 +0000 Subject: [PATCH 129/148] Simplify transformations --- src/DynamicPPL.jl | 1 - src/abstract_varinfo.jl | 66 +--------------- src/contexts.jl | 16 ++-- src/contexts/transformation.jl | 44 ----------- src/test_utils/contexts.jl | 2 +- src/threadsafe.jl | 55 -------------- src/varinfo.jl | 134 +++++++++++++++++---------------- 7 files changed, 81 insertions(+), 237 deletions(-) delete mode 100644 src/contexts/transformation.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6f4025ca..5889a6915 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -188,7 +188,6 @@ using .VarNamedTuples: VarNamedTuples, VarNamedTuple, map_pairs!!, map_values!!, include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") -include("contexts/transformation.jl") include("contexts/prefix.jl") include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c4af10898..67ac822cd 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -32,6 +32,9 @@ in the execution of a given `Model`. This is in constrast to `StaticTransformation` which transforms all variables _before_ the execution of a given `Model`. +Different VarInfo types should implement their own methods for `link!!` and `invlink!!` for +`DynamicTransformation`. + See also: [`StaticTransformation`](@ref). 
""" struct DynamicTransformation <: AbstractTransformation end @@ -53,23 +56,6 @@ struct StaticTransformation{F} <: AbstractTransformation bijector::F end -""" - merge_transformations(transformation_left, transformation_right) - -Merge two transformations. - -The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref). -""" -function merge_transformations(::NoTransformation, ::NoTransformation) - return NoTransformation() -end -function merge_transformations(::DynamicTransformation, ::DynamicTransformation) - return DynamicTransformation() -end -function merge_transformations(left::StaticTransformation, right::StaticTransformation) - return StaticTransformation(merge_bijectors(left.bijector, right.bijector)) -end - function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) end @@ -744,31 +730,6 @@ end function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{false}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, t) -end -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user - # interface. - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - # Set parameters and add the logjac term. - # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant - # that getindex(vi, vn) would apply the default link transform of the distribution. With - # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any - # transform. Neither is correct, rather the transform should be the inverse of b. - vi = unflatten!!(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return set_transformed!!(vi, t) -end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -811,27 +772,6 @@ end function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{true}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, NoTransformation()) -end -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = t.bijector - y = vi[:] - x, inv_logjac = with_logabsdet_jacobian(b, y) - - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - vi = unflatten!!(vi, x) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, inv_logjac) - end - return set_transformed!!(vi, NoTransformation()) -end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) diff --git a/src/contexts.jl b/src/contexts.jl index 46c5b8855..0eccf7b53 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -25,18 +25,18 @@ Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. 
# Examples -```jldoctest -julia> using DynamicPPL: DynamicTransformationContext, ConditionContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: InitContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior())); julia> DynamicPPL.childcontext(ctx_prior) -DynamicTransformationContext{true}() +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) ``` """ setchildcontext @@ -60,8 +60,8 @@ in which case effectively append `right` to `left`, dropping the original leaf context of `left`. # Examples -```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, InitContext julia> struct ParentContext{C} <: AbstractParentContext context::C @@ -77,8 +77,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) -DynamicTransformationContext{true}() + leafcontext(setleafcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior()))) +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl deleted file mode 100644 index 0914d7a79..000000000 --- a/src/contexts/transformation.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - struct DynamicTransformationContext{isinverse} <: AbstractContext - -When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to -constrained space if `isinverse` or unconstrained if `!isinverse`. - -Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the -`DynamicTransformationContext` methods with more efficient implementations. -`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know -how to do the transformation. -""" -struct DynamicTransformationContext{isinverse} <: AbstractContext end - -function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, - right::Distribution, - vn::VarName, - vi::AbstractVarInfo, -) where {isinverse} - # vi[vn, right] always provides the value in unlinked space. - x = vi[vn, right] - - if is_transformed(vi, vn) - isinverse || @warn "Trying to link an already transformed variable ($vn)" - else - isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" - end - - transform = isinverse ? 
identity : link_transform(right) - y, logjac = with_logabsdet_jacobian(transform, x) - vi = accumulate_assume!!(vi, x, logjac, vn, right) - vi = setindex!!(vi, y, vn) - return x, vi -end - -function tilde_observe!!( - ::DynamicTransformationContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - return tilde_observe!!(DefaultContext(), right, left, vn, vi) -end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index cceedee8c..7182f511e 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -49,7 +49,7 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + DynamicPPL.InitContext(Random.MersenneTwister(1234), InitFromPrior()) else DefaultContext() end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index d83cb289d..547dd6a1e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -82,61 +82,6 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, model) -end - -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model) -end - -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) -end - -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) -end - -# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. -# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates -# to define `getacc(vi)`. -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{false}()) - return set_transformed!!(last(evaluate!!(model, vi)), t) -end - -function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{true}()) - return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation()) -end - -function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end - -# These two StaticTransformation methods needed to resolve ambiguities -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) -end - function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. 
# NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the diff --git a/src/varinfo.jl b/src/varinfo.jl index 688c90b03..3af47691b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -257,38 +257,25 @@ function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - dists = extract_priors(model, vi) - cumulative_logjac = zero(LogProbType) - new_values = map_pairs!!(vi.values) do pair - vn, tv = pair - if vns !== nothing && !any(x -> subsumes(x, vn), vns) - # Not one of the target variables. - return tv - end - dist = getindex(dists, vn) - vec_transform = from_vec_transform(dist) - link_transform = from_linked_vec_transform(dist) - val_untransformed, logjac1 = with_logabsdet_jacobian(vec_transform, tv.val) - val_new, logjac2 = with_logabsdet_jacobian( - inverse(link_transform), val_untransformed - ) - new_tv = TransformedValue(val_new, true, link_transform, tv.size) - cumulative_logjac += logjac1 + logjac2 - return new_tv - end - vi = VarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi -end +""" + _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link isa Bool} -function link!!(t::DynamicTransformation, vi::VarInfo, model::Model) - return link!!(t, vi, nothing, model) -end +The internal function that implements both link!! and invlink!!. -function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) +The last argument controls whether linking (true) or invlinking (false) is performed. If +`vns` is `nothing`, all variables in `vi` are transformed; otherwise, only the variables +in `vns` are transformed. Existing variables already in the desired state are left +unchanged. +""" +function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link} + @assert link isa Bool + # Note that extract_priors causes a model execution. In the past with the Metadata-based + # VarInfo we rather derived the transformations from the distributions stored in the + # VarInfo itself. However, that is not fail-safe with dynamic models, and would require + # storing the distributions in TransformedValue (which we could start doing). Instead we + # use extract_priors to get the current, correct transformations. This logic is very + # similar to what DynamicTransformation used to do, and we might replace this with a + # context that transforms each variable in turn during the execution. dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -297,15 +284,23 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) # Not one of the target variables. return tv end - current_val = tv.val + if tv.linked == link + # Already in the desired state. 
+ return tv + end dist = getindex(dists, vn) vec_transform = from_vec_transform(dist) link_transform = from_linked_vec_transform(dist) - val_untransformed, logjac1 = with_logabsdet_jacobian(link_transform, current_val) + current_transform, new_transform = if link + (vec_transform, link_transform) + else + (link_transform, vec_transform) + end + val_untransformed, logjac1 = with_logabsdet_jacobian(current_transform, tv.val) val_new, logjac2 = with_logabsdet_jacobian( - inverse(vec_transform), val_untransformed + inverse(new_transform), val_untransformed ) - new_tv = TransformedValue(val_new, false, vec_transform, tv.size) + new_tv = TransformedValue(val_new, link, new_transform, tv.size) cumulative_logjac += logjac1 + logjac2 return new_tv end @@ -316,42 +311,51 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) return vi end -function invlink!!(t::DynamicTransformation, vi::VarInfo, model::Model) - return invlink!!(t, vi, nothing, model) +function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) + return _link_or_invlink!!(vi, vns, model, Val(true)) end - -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + return _link_or_invlink!!(vi, nothing, model, Val(true)) end - -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) + return _link_or_invlink!!(vi, vns, model, Val(false)) +end +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + return _link_or_invlink!!(vi, nothing, model, Val(false)) +end + +function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) + # TODO(mhauru) This assumes that the user has defined the bijector using the same + # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user + # interface. + b = inverse(t.bijector) + x = vi[:] + y, logjac = with_logabsdet_jacobian(b, x) + # Set parameters and add the logjac term. + # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant + # that getindex(vi, vn) would apply the default link transform of the distribution. With + # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any + # transform. Neither is correct, rather the transform should be the inverse of b. + vi = unflatten!!(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return set_transformed!!(vi, t) end -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. 
- return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end +function invlink!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) + b = t.bijector + y = vi[:] + x, inv_logjac = with_logabsdet_jacobian(b, y) -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten!!(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) + end + return set_transformed!!(vi, NoTransformation()) end # TODO(mhauru) I don't think this should return the internal values, but that's the current From 2fa7333ea1a0a19fec823da2e13c69e58cfc062e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 18:19:32 +0000 Subject: [PATCH 130/148] Fix docs --- docs/src/api.md | 23 ++++++++--------------- docs/src/internals/varinfo.md | 1 - src/varinfo.jl | 3 +-- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index a506c793e..5cd94fccd 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -343,6 +343,8 @@ AbstractVarInfo ```@docs VarInfo +DynamicPPL.TransformedValue +DynamicPPL.setindex_with_dist!! ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). @@ -354,11 +356,7 @@ is_transformed set_transformed!! ``` -```@docs -Base.empty! -``` - -#### `VarNamedTuple` +#### `VarNamedTuple`s `VarInfo` is only a thin wrapper around [`VarNamedTuple`](@ref), which stores arbitrary data keyed by `VarName`s. For more details on `VarNamedTuple`, see the Internals section of our documentation. @@ -366,6 +364,10 @@ For more details on `VarNamedTuple`, see the Internals section of our documentat ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple DynamicPPL.VarNamedTuples.vnt_size +DynamicPPL.VarNamedTuples.apply!! +DynamicPPL.VarNamedTuples.map_pairs!! +DynamicPPL.VarNamedTuples.map_values!! +DynamicPPL.VarNamedTuples.PartialArray ``` ### Accumulators @@ -411,19 +413,10 @@ accloglikelihood!! ```@docs keys getindex -push!! empty!! isempty DynamicPPL.getindex_internal -DynamicPPL.setindex_internal! -DynamicPPL.update_internal! -DynamicPPL.insert_internal! -DynamicPPL.length_internal -DynamicPPL.reset! -DynamicPPL.update! -DynamicPPL.insert! -DynamicPPL.loosen_types!! -DynamicPPL.tighten_types!! +DynamicPPL.setindex_internal!! ``` ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index c57ea1fcf..f3f100a81 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -39,7 +39,6 @@ One can access a vectorised version of a variable's value with the following vec - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. 
- `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values - - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`. The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised. diff --git a/src/varinfo.jl b/src/varinfo.jl index 3af47691b..4cda6b40f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -60,8 +60,7 @@ struct TransformedValue{ValType,TransformType,SizeType} actual value of the variable being stored is `transform(val)`.""" transform::TransformType """The size of the actual value after transformation. This is needed when a - TransformedValue is stored as a block in an array (see [`PartialArray`](@ref) in - `VarNamedTuples`).""" + `TransformedValue` is stored as a block in an array.""" size::SizeType end From 7cbc4a7bf6c1f5657fb8258fd8c1b460d19e8367 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 18:21:23 +0000 Subject: [PATCH 131/148] Mark some inference tests as broken on 1.10 --- test/varnamedtuple.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 1937ea189..f3f1e83e6 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -667,6 +667,9 @@ Base.size(st::SizedThing) = st.size end @testset "length" begin + # Type inference for length fails in some cases on Julia versions < 1.11 + inference_broken = VERSION < v"1.11" + vnt = VarNamedTuple() @test @inferred(length(vnt)) == 0 @@ -683,23 +686,23 @@ Base.size(st::SizedThing) = st.size @test @inferred(length(vnt)) == 3 vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test @inferred(length(vnt)) == 4 + @test @inferred(length(vnt)) == 4 broken = inference_broken vnt = setindex!!(vnt, ["a", "b"], @varname(d[1:2])) - @test @inferred(length(vnt)) == 6 + @test @inferred(length(vnt)) == 6 broken = inference_broken vnt = setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i)) vnt = setindex!!(vnt, 3.0, @varname(e.f[3].g.h[2].j)) - @test @inferred(length(vnt)) == 8 + @test @inferred(length(vnt)) == 8 broken = inference_broken vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 2:4, 2, 1:2, 3])) - @test @inferred(length(vnt)) == 14 + @test @inferred(length(vnt)) == 14 broken = inference_broken vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) - @test @inferred(length(vnt)) == 14 + @test @inferred(length(vnt)) == 14 broken = inference_broken vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) - @test @inferred(length(vnt)) == 16 + @test @inferred(length(vnt)) == 16 broken = inference_broken test_invariants(vnt) end @@ -917,7 +920,9 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) - @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val + # Type inference fails on this one for Julia versions < 1.11 + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val broken = + VERSION < v"1.11" end @testset "map and friends" begin From b4361c04612731e0357a17ac42dc850a055e003c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:02:12 +0000 Subject: [PATCH 132/148] Polish VNT and tests --- src/varnamedtuple.jl | 41 +++++++++++++++++++---------------------- test/varnamedtuple.jl | 6 +++++- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/varnamedtuple.jl 
b/src/varnamedtuple.jl index 37158442b..0287e393b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -26,7 +26,7 @@ function _haskey end Like `setindex!!`, but special-cased for `VarNamedTuple` and `PartialArray` to recurse into nested structures. -The `allow_new` keywword argument is a performance optimisation: If it is set to +The `allow_new` keyword argument is a performance optimisation: If it is set to `Val(false)`, the function can assume that the key being set already exists in `collection`. This allows skipping some code paths, which may have a minor benefit at runtime, but more importantly, allows for better constant propagation and type stability at compile time. @@ -541,7 +541,7 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) - # The original, non-bare inds is needed later for ArrayLikeBlock checks. + # The unmodified inds is needed later for ArrayLikeBlock checks. orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) @@ -1237,6 +1237,7 @@ end end end +# As above but with a prefix VarName `vn`. @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} exs = Expr[] for name in Names @@ -1273,10 +1274,13 @@ map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), Apply `f` to all elements of `vnt`, and reduce the results using `op`, starting from `init`. +The order is the same as in `mapfoldl`, i.e. left-associative with `init` as the +left-most value. + `init` is a keyword argument to conform to the usual `mapreduce` interface in Base, but it is not optional. -`f` op` should accept pairs of `VarName` and value. +`f` op` should accept pairs of `varname => value`. """ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing @@ -1298,41 +1302,30 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa. @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, init ) where {Names} - exs = Expr[] - push!( - exs, - quote - result = init - end, - ) + exs = Expr[:(result = init)] for name in Names push!( exs, - :( + quote result = _mapreduce_recursive( f, op, vnt.data.$name, VarName{$(QuoteNode(name))}(), result ) - ), + end, ) end push!(exs, :(return result)) return Expr(:block, exs...) end +# As above but with a prefix VarName `vn`. @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, vn, init ) where {Names} - exs = Expr[] - push!( - exs, - quote - result = init - end, - ) + exs = Expr[:(result = init)] for name in Names push!( exs, - :( + quote result = _mapreduce_recursive( f, op, @@ -1340,7 +1333,7 @@ end AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn), result, ) - ), + end, ) end push!(exs, :(return result)) @@ -1354,7 +1347,7 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) albs_seen = Set{ArrayLikeBlock}() @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] - val = @inbounds pa.data[i] + val = pa.data[i] is_alb = val isa ArrayLikeBlock if is_alb if val in albs_seen @@ -1370,6 +1363,10 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) return result end +# TODO(mhauru) We could try to keep the return types of these more tight, rather than always +# return the same, abstract element type. Would that be better? 
It would be faster in some +# cases, but would be less consistent, and could result in a lot of allocations in the +# mapreduce, as the element type is gradually expanded. Base.keys(vnt::VarNamedTuple) = mapreduce(first, push!, vnt; init=VarName[]) Base.values(vnt::VarNamedTuple) = mapreduce(pair -> pair.second, push!, vnt; init=Any[]) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index f3f1e83e6..7c2a263f5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -502,6 +502,9 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) test_invariants(vnt) + # TODO(mhauru) I'm a bit saddened by the lack of type stability for subset: It's + # return type always infers as VarNamedTuple. Improving this would require a + # different implementation of subset. @test subset(vnt, VarName[]) == VarNamedTuple() @test subset(vnt, (@varname(z),)) == VarNamedTuple() @test subset(vnt, (@varname(d[4]),)) == VarNamedTuple() @@ -1025,7 +1028,8 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" call_counter = 0 - vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) + vnt_applied = copy(vnt) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(a))) @test call_counter == 1 test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 From 73e50df5d576fd76473888a4244cf88c2ac59f6f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:10:07 +0000 Subject: [PATCH 133/148] Fix broken test marking --- test/varnamedtuple.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7c2a263f5..63ee12c5b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -924,8 +924,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) # Type inference fails on this one for Julia versions < 1.11 - @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val broken = - VERSION < v"1.11" + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end @testset "map and friends" begin @@ -1048,7 +1047,14 @@ Base.size(st::SizedThing) = st.size test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] - vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) + vnt_applied = begin + # The @inferred fails on Julia 1.10. 
+ @static if VERSION < v"1.11" + apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i)) + else + @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) + end + end @test call_counter == 4 test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" From 922fbb62f3c3717ddf36a68cad22f7dc2f082ac3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:46:41 +0000 Subject: [PATCH 134/148] Polish varinfo.jl --- src/varinfo.jl | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cda6b40f..4dfa538c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -145,34 +145,32 @@ Set the value of `vn` in `vi` to `val`, applying a transformation based on `dist `val` is taken to be the actual value of the variable, and is transformed into the internal (vectorised) representation using a transformation based on `dist`. If the variable is -linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are linked, the -linking transformation is used; otherwise, the standard vector transformation is used. +currently linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are +linked, the linking transformation is used; otherwise, the standard vector transformation is +used. Returns the modified `vi` together with the log absolute determinant of the Jacobian of the transformation applied. """ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) - # Determine whether to insert a transformed value into `vi`. - # If the VarInfo alrady had a value for this variable, we will - # keep the same linked status as in the original VarInfo. If not, we - # check the rest of the VarInfo to see if other variables are linked. - # is_transformed(vi) returns true if vi is nonempty and all variables in vi - # are linked. - insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) - # TODO(mhauru) We should move away from having all values vectorised by default. - # That messes with our use of unflatten though, so will require some thought. - transform = if insert_transformed_value + link = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) + transform = if link from_linked_vec_transform(dist) else from_vec_transform(dist) end transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) + # All values for which `size` is not defined are assumed to be scalars. val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size) + tv = TransformedValue(transformed_val, link, transform, val_size) vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end +# TODO(mhauru) The below is somewhat unsafe or incomplete: For instance, from_vec_transform +# isn't defined for NamedTuples. However, this is needed in some places where values for +# in a VarInfo are set outside the context of a `tilde_assume!!` and no distribution is +# available. Hopefully we'll get rid of this eventually. """ setindex!!(vi::VarInfo, val, vn::VarName) @@ -228,17 +226,11 @@ end Get the internal (vectorised) value of variable `vn` in `vi`. 
""" -function getindex_internal(vi::VarInfo, vn::VarName) - tv = getindex(vi.values, vn) - return tv.val -end - +getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val +# TODO(mhauru) The below should be removed together with unflatten!!. getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -function is_transformed(vi::VarInfo, vn::VarName) - tv = getindex(vi.values, vn) - return tv.linked -end +is_transformed(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).linked function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) @@ -253,6 +245,9 @@ function from_internal_transform(vi::VarInfo, vn::VarName) end function from_linked_internal_transform(vi::VarInfo, vn::VarName) + if !is_transformed(vi, vn) + error("Variable $vn is not linked; cannot get linked transformation.") + end return getindex(vi.values, vn).transform end @@ -330,11 +325,10 @@ function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::M b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - # Set parameters and add the logjac term. # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant # that getindex(vi, vn) would apply the default link transform of the distribution. With # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any - # transform. Neither is correct, rather the transform should be the inverse of b. + # link transform. Neither is correct, rather the transform should be the inverse of b. vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) @@ -417,7 +411,7 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) old_val = tv.val if !(old_val isa AbstractVector) error( - "Can not unflatten a VarInfo for which existing values are not vectors:" * + "Can't unflatten a VarInfo for which existing values are not vectors:" * " Got value of type $(typeof(old_val)).", ) end @@ -445,6 +439,8 @@ end Merge two `VarInfo`s into a new `VarInfo` containing all variables from both. +The accumulators are taken exclusively from `varinfo_right`. + If a variable exists in both `varinfo_left` and `varinfo_right`, the value from `varinfo_right` is used. """ From 66c79709942abbccc38fe01d158668b04719f96d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:56:09 +0000 Subject: [PATCH 135/148] Polish internal docs --- docs/src/internals/varinfo.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index f3f100a81..6d87e5edc 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -14,7 +14,7 @@ It contains `values` takes care of storing information related to values of individual random variables, while `accs` keeps track of information that we keep accumulating in the course of evaluating through a model. Variables are regonised by their `VarName`. -We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information. +We want to work with `VarName`s rather than something like `Symbol` or `String` as `VarName` contains additional structural information. For instance, a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. `VarName`s also allow things such as setting values for `x[1]` and `x[2]` and getting a value for `x` as a whole. 
@@ -24,7 +24,6 @@ To ensure that `VarInfo` is simple and intuitive to work with we want it to repl

   - `haskey(::VarInfo)`: check if a particular `VarName` is present.
   - `getindex(::VarInfo, ::VarName)`: return the realization corresponding to a particular `VarName`.
   - `setindex!!(::VarInfo, val, ::VarName)`: set the realization corresponding to a particular `VarName`.
-  - `delete!!(::VarInfo, ::VarName)`: delete the realization corresponding to a particular `VarName`.
   - `empty!!(::VarInfo)`: delete all data.
   - `merge(::VarInfo, ::VarInfo)`: merge two containers according to similar rules as `Dict`.

@@ -36,15 +35,12 @@ One can access a vectorised version of a variable's value with the following vec

   - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable.
   - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables.
-  - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values
   - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
-  - `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values

 The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised.

 Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space.
 `getindex_internal` and `setindex_internal!!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able to sample in unconstrained space.
-One can also manually set a transformation by giving `setindex_internal!!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable.

 Finally, we want the underlying storage to have a few performance-related properties:

From 51fdcbec611911be498d4a6499c5d6f09ae24cf1 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Tue, 13 Jan 2026 19:58:59 +0000
Subject: [PATCH 136/148] More broken inference tests on v1.10

---
 test/varnamedtuple.jl | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index 63ee12c5b..abd1406e6 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -1060,7 +1060,14 @@ Base.size(st::SizedThing) = st.size
         @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab"
         @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0

-        vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j)))
+        vnt_applied = begin
+            # The @inferred fails on Julia 1.10.
+            @static if VERSION < v"1.11"
+                apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))
+            else
+                @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j)))
+            end
+        end
         @test call_counter == 5
         test_invariants(vnt_applied; skip=(:parseeval,))
         @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab"

From 07a13c4532c53da6363ed239c36d2cfb045300ce Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Wed, 14 Jan 2026 15:11:09 +0000
Subject: [PATCH 137/148] Export VarNamedTuple and its functions

---
 src/DynamicPPL.jl         | 4 ++++
 src/values_as_in_model.jl | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 5889a6915..9961125e2 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -46,6 +46,10 @@ import Base:

     # VarInfo
 export AbstractVarInfo,
     VarInfo,
+    VarNamedTuple,
+    map_pairs!!,
+    map_values!!,
+    apply!!,
     AbstractAccumulator,
     LogLikelihoodAccumulator,
     LogPriorAccumulator,
diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl
index 304b99a3e..9ee622424 100644
--- a/src/values_as_in_model.jl
+++ b/src/values_as_in_model.jl
@@ -111,7 +111,7 @@ julia> # Construct initial `VarInfo`.
        varinfo = VarInfo(rng, model);

julia> # Link it so it works in unconstrained space.
-       varinfo_linked = DynamicPPL.link!!(copy(varinfo), model);
+       varinfo_linked = DynamicPPL.link(varinfo, model);

julia> # Perform computations in unconstrained space, e.g. changing the values of `vals`.
       # Flip `x` so we hit the other support of `y`.
@@ -125,7 +125,7 @@ julia> # Determine the expected support of `y`.
 (0, 1)

julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
-       varinfo_invlinked = DynamicPPL.invlink!!(copy(varinfo_linked), model);
+       varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);

julia> lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub
true

From 92dd490a0839b8a7b703d8efc838f6eb70e3c537 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Wed, 14 Jan 2026 15:34:46 +0000
Subject: [PATCH 138/148] Add HISTORY.md entry on the new VarInfo

---
 HISTORY.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/HISTORY.md b/HISTORY.md
index bb40b8464..d6c13f7a6 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -2,6 +2,68 @@

 ## 0.40

+### `VarNamedTuple`
+
+DynamicPPL now exports a new type called `VarNamedTuple`, which stores values keyed by `VarName`s.
+Alongside it, a few new functions for working with it are exported: `map_values!!`, `map_pairs!!`, and `apply!!`.
+Our documentation's Internals section now has a page about `VarNamedTuple`, how it works, and what it's good for.
+
+`VarNamedTuple` is now used internally in many different places: in `VarInfo`, in `values_as_in_model`, in `LogDensityFunction`, etc.
+Almost all of the changes below are a consequence of switching over to using `VarNamedTuple` for various features internally.
+
+### Overhaul of `VarInfo`
+
+DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types.
+Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped, and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends.
+These have all been replaced by a rewritten implementation of `VarInfo`.
+While the basics of the `VarInfo` interface remain the same, this brings with it many changes:
+
+#### No more multiple `AbstractVarInfo` types
+
+`SimpleVarInfo`, `untyped_varinfo`, `typed_varinfo`, and many other constructors, some exported, some not, have been removed.
+The remaining one is `VarInfo(...)`, which can take a model or a collection of values.
+See the docstring for details.
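+
+As a minimal, illustrative sketch of the remaining constructor (the model `m` and the variable name `x` below are hypothetical; the `VarInfo` docstring is the authoritative reference):
+
+```julia
+using DynamicPPL
+
+vi = VarInfo(m)                         # fresh values sampled from the prior of model `m` (hypothetical)
+vi = VarInfo((; x=1.0))                 # from a NamedTuple of known values
+vi = VarInfo(Dict(@varname(x) => 1.0))  # from a Dict keyed by `VarName`
+```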
+
+Some related types and functions that weren't exported but may have been used by some have also been removed, most notably `VarNamedVector` and its associated functions like `loosen_types!!` and `tighten_types!!`.
+
+#### Setting and getting values
+
+Previously, the various `AbstractVarInfo` types had a multitude of functions for setting values:
+`push!!`, `push!`, `setindex!`, `update!`, `update_internal!`, `insert_internal!`, `reset!`, etc.
+These have all been replaced by three functions:
+
+  - `setindex!!` is the one to use for simply setting a variable in `VarInfo` to a known value. It works regardless of whether the variable already exists.
+  - `setindex_internal!!` is the one to use for setting the internal, vectorised representation of a variable. See the docstring for details.
+  - `setindex_with_dist!!` is to be used when you want to set a value, but have the internal representation chosen based on which distribution the value is a sample from.
+
+The order of the arguments for some of these functions has also changed, and now more closely matches the usual convention for `setindex!!`.
+
+Note that `setindex!` (with a single `!`) is not defined, and thus you can't do `varinfo[varname] = new_value`.
+
+`unflatten` works as before, but has been renamed to `unflatten!!`, since it may mutate its first argument and alias memory with the second argument (it uses views rather than copies of the input vector).
+
+#### Linking is now safer
+
+`link!!` and `invlink!!` on `VarInfo` used to assume that the prior distribution of a variable didn't change from one execution to another (as it does in e.g. `truncated(dist; lower=x)` where `x` is a random variable).
+This is no longer the case.
+Linking should thus be safer to do.
+The cost is that calls to `link!!` and `invlink!!` (and the non-mutating versions) now trigger a model evaluation, to determine the correct priors to use.
+
+#### Other miscellanea
+
+  - The `Experimental` module had functions like `Experimental.determine_suitable_varinfo` for determining which `AbstractVarInfo` type was suitable for a given model. This is now redundant and has been removed.
+  - `Bijectors.bijector(::Model)`, which creates a bijector from the vectorised variable space of the model to the linked variable space of the model, now has slightly different optional arguments. See the docstring for details.
+  - `NamedDist` no longer allows variable names with `Colon`s in them, such as `x[:]`.
+
+There are probably also changes to the `VarInfo` interface that we've neglected to document here, since the overhaul of `VarInfo` has been quite thorough.
+If anything related to `VarInfo` is behaving unexpectedly, e.g. the arguments or return type of a function seem to have changed, please check the docstring, which should be comprehensive.
+
+#### Performance benefits
+
+The purpose of this overhaul of `VarInfo` is code simplification and performance benefits.
+
+TODO(mhauru) Add some basic summary of what has gotten faster by how much.
+
 ### Changes to indexing random variables with square brackets

 0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. ones that use indexing with square brackets.
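+
+For reference, a small illustrative sketch of a model that uses this kind of indexing (the model itself is hypothetical and only demonstrates the syntax in question):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function demo()
+    x = Vector{Float64}(undef, 2)
+    x[1] ~ Normal()  # square-bracket indexing on the left of a tilde
+    x[2] ~ Normal()
+    return x
+end
+```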
From 06f6c1efdb3f566a0f1e49162e7e75c5d5396fb9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:55:24 +0000 Subject: [PATCH 139/148] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4dfa538c1..06623ca25 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -12,9 +12,9 @@ to be valid internal representations of the variable (i.e., whether the variable linked), as well as the size of the actual post-transformation value. These are all fields of [`TransformedValue`](@ref). -Note that `setindex!!` and `getindex` on `VarInfo` deal with the actual values of variables. -To get access to the internal vectorised values, use [`getindex_internal`](@ref), -[`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). +Note that `setindex!!` and `getindex` on `VarInfo` take and return values in the support of +the original distribution. To get access to the internal vectorised values, use +[`getindex_internal`](@ref), [`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which sets a variable's value with a transformation based on the statistical distribution this value is @@ -105,7 +105,7 @@ function Base.getindex(vi::VarInfo, vn::VarName) return tv.transform(tv.val) end -function Base.getindex(vi::VarInfo, vns::Vector{<:VarName}) +function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) return [getindex(vi, vn) for vn in vns] end From 4f893bc1665169df66de4ae088fb81edc8bb81c8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:57:57 +0000 Subject: [PATCH 140/148] Use SkipSizeCheck rather than Val(:pass) --- src/varnamedtuple.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0287e393b..802b8222b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -102,6 +102,13 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize _unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range _unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x +""" + SkipSizeCheck() + +A special return value for `vnt_size` indicating that size checks should be skipped. +""" +struct SkipSizeCheck end + """ vnt_size(x) @@ -111,7 +118,7 @@ By default, this falls back onto `Base.size`, but can be overloaded for custom t This notion of type is used to determine whether a value can be set into a `PartialArray` as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. -A special return value of `Val(:pass)` indicates that the size check should be skipped. +A special return value of `SkipSizeCheck()` indicates that the size check should be skipped. """ vnt_size(x) = size(x) @@ -301,7 +308,7 @@ _internal_size(pa::PartialArray, args...) = size(pa.data, args...) # be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. # Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a # valid key. 
-vnt_size(pa::PartialArray) = Val(:pass) +vnt_size(::PartialArray) = SkipSizeCheck() function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively @@ -686,7 +693,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if vnt_size(value) !== Val(:pass) && vnt_size(value) != inds_size + if !(vnt_size(value) isa SkipSizeCheck) && vnt_size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(vnt_size(value)), which does not match " * @@ -1216,7 +1223,7 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) sz_new = vnt_size(new_block) sz_old = vnt_size(alb.block) - if sz_new !== Val(:pass) && sz_old !== Val(:pass) && sz_new != sz_old + if !(sz_new isa SkipSizeCheck) && !(sz_old isa SkipSizeCheck) && sz_new != sz_old throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * From fdb1373c78a68b96069e2c1c2f5b9096a82a0a57 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 16:01:27 +0000 Subject: [PATCH 141/148] Remove getindex with dist argument --- src/varinfo.jl | 5 ----- test/linking.jl | 6 +++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 06623ca25..3e026648b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -109,11 +109,6 @@ function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) return [getindex(vi, vn) for vn in vns] end -function Base.getindex(vi::VarInfo, vn::VarName, dist::Distribution) - val = getindex_internal(vi, vn) - return from_maybe_linked_internal(vi, vn, dist, val) -end - Base.isempty(vi::VarInfo) = isempty(vi.values) Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) diff --git a/test/linking.jl b/test/linking.jl index 2047b9d11..bfd1285b1 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -89,7 +89,7 @@ end DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) # The non-internal logjoint should be the same since it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) + @test vi_linked[@varname(m)] == LowerTriangular(vi[@varname(m)]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @test length(vi_linked[:]) == length(y) @@ -100,7 +100,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == length(vi[:]) - @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + @test vi_invlinked[@varname(m)] ≈ LowerTriangular(vi[@varname(m)]) # The non-internal logjoint should still be the same, again since # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) @@ -121,7 +121,7 @@ end model, values_original, (@varname(x),) ) @testset "$(short_varinfo_name(vi))" for vi in vis - val = vi[@varname(x), dist] + val = vi[@varname(x)] # Ensure that `reconstruct` works as intended. 
@test val isa Cholesky @test val.uplo == uplo From a023a7fc6c8a57d3941510afda1bff070d7a2dd0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 16:10:59 +0000 Subject: [PATCH 142/148] Simplify map and mapreduce for VNT --- src/varnamedtuple.jl | 57 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 802b8222b..3f758852e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1234,16 +1234,6 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) return ArrayLikeBlock(new_block, alb.inds) end -@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} - exs = Expr[] - for name in Names - push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) - end - return quote - return VarNamedTuple(NamedTuple{Names}(($(exs...),))) - end -end - # As above but with a prefix VarName `vn`. @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} exs = Expr[] @@ -1267,7 +1257,17 @@ Apply `func` to all key => value pairs of `vnt`, in place if possible. `func` should accept a pair of `VarName` and value, and return the new value to be set. """ -map_pairs!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) +@generated function map_pairs!!(func, vnt::VarNamedTuple{Names}) where {Names} + exs = Expr[] + for name in Names + push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +Base.foreach(func, vnt::VarNamedTuple) = map_pairs!!(p -> (func(p); p), vnt) """ map_values!!(func, vnt::VarNamedTuple) @@ -1289,26 +1289,19 @@ is not optional. `f` op` should accept pairs of `varname => value`. """ -function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) - if init === nothing - throw( - NotImplementedError( - "mapreduce without init is not implemented for VarNamedTuple." - ), - ) +@generated function Base.mapreduce( + f, op, vnt::VarNamedTuple{Names}; init::InitType=nothing +) where {Names,InitType} + if InitType === Nothing + return quote + throw( + ArgumentError( + "mapreduce without init is not implemented for VarNamedTuple." + ), + ) + end end - return _mapreduce_recursive(f, op, vnt, init) -end - -# Our mapreduce is always left-associative. -Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) - -_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) -_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) -@generated function _mapreduce_recursive( - f, op, vnt::VarNamedTuple{Names}, init -) where {Names} exs = Expr[:(result = init)] for name in Names push!( @@ -1324,6 +1317,12 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa. return Expr(:block, exs...) end +# Our mapreduce is always left-associative. +Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) + +_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) + # As above but with a prefix VarName `vn`. 
@generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, vn, init From c369b0926bceefa859883daf4777280ec98b13b8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 10:28:16 +0000 Subject: [PATCH 143/148] Remove unused utility functions --- src/abstract_varinfo.jl | 8 +- src/accumulators.jl | 3 +- src/chains.jl | 2 +- src/contexts/init.jl | 13 +- src/logdensityfunction.jl | 10 +- src/utils.jl | 320 -------------------------------------- src/varinfo.jl | 2 +- src/varname.jl | 17 -- test/utils.jl | 23 --- test/varinfo.jl | 4 +- 10 files changed, 21 insertions(+), 381 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 67ac822cd..51341e3d4 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -727,7 +727,7 @@ See also: [`default_transformation`](@ref), [`invlink!!`](@ref). function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link!!(vi::AbstractVarInfo, vns, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end @@ -746,7 +746,7 @@ See also: [`default_transformation`](@ref), [`invlink`](@ref). function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link(vi::AbstractVarInfo, vns, model::Model) return link(default_transformation(model, vi), vi, vns, model) end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) @@ -769,7 +769,7 @@ See also: [`default_transformation`](@ref), [`link!!`](@ref). function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink!!(vi::AbstractVarInfo, vns, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end @@ -789,7 +789,7 @@ See also: [`default_transformation`](@ref), [`link`](@ref). function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink(vi::AbstractVarInfo, vns, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/accumulators.jl b/src/accumulators.jl index 0208f19a5..ae1c26094 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -118,8 +118,7 @@ See also: [`split`](@ref) """ function combine end -# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in -# src/varinfo.jl. +# TODO(mhauru) The existence of this function makes me sad. See comment in src/model.jl. """ convert_eltype(::Type{T}, acc::AbstractAccumulator) diff --git a/src/chains.jl b/src/chains.jl index ee4312547..cfd27d87a 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -113,7 +113,7 @@ Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided `param_vector`. This method is intended to replace the old method of obtaining parameters and statistics -via `unflatten` plus re-evaluation. It is faster for two reasons: +via `unflatten!!` plus re-evaluation. It is faster for two reasons: 1. 
It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 65ea08ec5..d92bc35f8 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -56,12 +56,13 @@ used to determine whether the float type needs to be modified). In case that wasn't enough: in fact, even the above is not always true. Firstly, the accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments -in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable -of automatically promoting the types on its own. Secondly, the promotion only matters if you -are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar -tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to -tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which -also does the promotion for you. For the gory details, see the following issues: +in `DynamicPPL.unflatten!!` for more details. For non-threadsafe evaluation, Julia is +capable of automatically promoting the types on its own. Secondly, the promotion only +matters if you are trying to directly assign into a `Vector{Float64}` with a +`ForwardDiff.Dual` or similar tracer type, for example using `xs[i] = MyDual`. This doesn't +actually apply to tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` +under the hood, which also does the promotion for you. For the gory details, see the +following issues: - https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types - https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 17101d0d2..6ae1dc3a1 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -87,9 +87,9 @@ from: Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a given set of parameters: -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during - evaluation. +1. With `unflatten!!` + `evaluate!!` with `DefaultContext`: this stores a vector of + parameters inside a VarInfo's metadata, then reads parameter values from the VarInfo + during evaluation. 2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores them inside a VarInfo's metadata. @@ -114,7 +114,7 @@ In particular, it is not clear: - which parts of the vector correspond to which random variables, and - whether the variables are linked or unlinked. -Traditionally, this problem has been solved by `unflatten`, because that function would +Traditionally, this problem has been solved by `unflatten!!`, because that function would place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. However, we want to avoid using a metadata. Thus, here, we _extract this information from @@ -131,7 +131,7 @@ the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot models which have variable numbers of parameters, or models which may visit random variables in different orders depending on stochastic control flow. 
**Indeed, silent errors may occur with such models.** This is a general limitation of vectorised parameters: the original -`unflatten` + `evaluate!!` approach also fails with such models. +`unflatten!!` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ # true if all variables are linked; false if all variables are unlinked; nothing if diff --git a/src/utils.jl b/src/utils.jl index 4a0eea96c..f0f46157b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,13 +2,6 @@ # defined in other files. function subset end -# singleton for indicating if no default arguments are present -struct NoDefault end -const NO_DEFAULT = NoDefault() - -# A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameTuple = NTuple{N,VarName} where {N} - """ The type for all log probability variables. @@ -545,275 +538,6 @@ tovec(t::Tuple) = mapreduce(tovec, vcat, t) tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt)) tovec(C::Cholesky) = tovec(Matrix(C.UL)) -""" - recombine(dist::Union{UnivariateDistribution,MultivariateDistribution}, vals::AbstractVector, n::Int) - -Recombine `vals`, representing a batch of samples from `dist`, so that it's a compatible with `dist`. - -!!! warning - This only supports `UnivariateDistribution` and `MultivariateDistribution`, which are the only two - distribution types which are allowed on the right-hand side of a `.~` statement in a model. -""" -function recombine(::UnivariateDistribution, val::AbstractVector, ::Int) - # This is just a no-op, since we're trying to convert a vector into a vector. - return copy(val) -end -function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) - # Here `val` is of the length `length(d) * n` and so we need to reshape it. - return copy(reshape(val, length(d), n)) -end - -####################### -# Convenience methods # -####################### -""" - collect_maybe(x) - -Return `x` if `x` is an array, otherwise return `collect(x)`. -""" -collect_maybe(x) = collect(x) -collect_maybe(x::AbstractArray) = x - -####################### -# BangBang.jl related # -####################### -function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) - opticmut = BangBang.prefermutation(optic) - return Accessors.set(obj, opticmut, value) -end -function set!!(obj, vn::VarName{sym}, value) where {sym} - optic = BangBang.prefermutation( - AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - ) - return Accessors.set(obj, optic, value) -end - -############################# -# AbstractPPL.jl extensions # -############################# -# This is preferable to `haskey` because the order of arguments is different, and -# we're more likely to specialize on the key in these settings rather than the container. -# TODO: I'm not sure about this name. -""" - canview(optic, container) - -Return `true` if `optic` can be used to view `container`, and `false` otherwise. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) -julia> canview(@o(_.a), (a = 1.0, )) -true - -julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist -false - -julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) -true - -julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds -false -``` -""" -canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(optic::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) -end - -# `IndexLens`: only relevant if `x` supports indexing. 
-canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) -end - -# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::Accessors.ComposedOptic, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) -end - -""" - parent(vn::VarName) - -Return the parent `VarName`. - -# Examples -```julia-repl; setup=:(using DynamicPPL: parent) -julia> parent(@varname(x.a[1])) -x.a - -julia> (parent ∘ parent)(@varname(x.a[1])) -x - -julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) -x -``` -""" -function parent(vn::VarName) - p = parent(getoptic(vn)) - return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p) -end - -""" - parent(optic) - -Return the parent optic. If `optic` doesn't have a parent, -`nothing` is returned. - -See also: [`parent_and_child`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) -julia> parent(@o(_.a[1])) -(@o _.a) - -julia> # Parent of optic without parents results in `nothing`. - (parent ∘ parent)(@o(_.a[1])) === nothing -true -``` -""" -parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) - -""" - parent_and_child(optic) - -Return a 2-tuple of optics `(parent, child)` where `parent` is the -parent optic of `optic` and `child` is the child optic of `optic`. - -If `optic` does not have a parent, we return `(nothing, optic)`. - -See also: [`parent`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) -julia> parent_and_child(@o(_.a[1])) -((@o _.a), (@o _[1])) - -julia> parent_and_child(@o(_.a)) -(nothing, (@o _.a)) -``` -""" -parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) -function parent_and_child(optic::Accessors.ComposedOptic) - p, child = parent_and_child(optic.outer) - parent = p === nothing ? optic.inner : p ∘ optic.inner - return parent, child -end - -""" - splitoptic(condition, optic) - -Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. - -If `issuccess` is `false`, then no such split could be found. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) -julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent - # Succeeds! - parent == @o(_.a) - end -((@o _.a), (@o _[1]), true) - -julia> c ∘ p -(@o _.a[1]) - -julia> splitoptic(@o(_.a[1])) do parent - # Fails! - parent == @o(_.b) - end -(nothing, (@o _.a[1]), false) -``` -""" -function splitoptic(condition, optic) - current_parent, current_child = parent_and_child(optic) - # We stop if either a) `condition` is satisfied, or b) we reached the root. - while !condition(current_parent) && current_parent !== nothing - current_parent, c = parent_and_child(current_parent) - current_child = current_child ∘ c - end - - return current_parent, current_child, condition(current_parent) -end - -""" - remove_parent_optic(vn_parent::VarName, vn_child::VarName) - -Remove the parent optic `vn_parent` from `vn_child`. 
- -# Examples -```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) -julia> remove_parent_optic(@varname(x), @varname(x.a)) -(@o _.a) - -julia> remove_parent_optic(@varname(x), @varname(x.a[1])) -(@o _.a[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) -(@o _[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) -(@o _[1].b) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a)) -ERROR: Could not find x.a in x.a - -julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) -ERROR: Could not find x.a[2] in x.a[1] -``` -""" -function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitoptic(getoptic(vn_child)) do optic - o = optic === nothing ? identity : optic - o == getoptic(vn_parent) - end - - issuccess || error("Could not find $vn_parent in $vn_child") - return child -end - -# HACK(torfjelde): This makes it so it works on iterators, etc. by default. -# TODO(torfjelde): Do better. -""" - unflatten(original, x::AbstractVector) - -Return instance of `original` constructed from `x`. -""" -function unflatten(original, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return map(zip(original, lengths, end_indices)) do (v, l, end_idx) - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end - -unflatten(::Real, x::Real) = x -unflatten(::Real, x::AbstractVector) = only(x) -unflatten(::AbstractVector{<:Real}, x::Real) = vcat(x) -unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x -unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) - -function unflatten(original::Tuple, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return ntuple(length(original)) do i - v = original[i] - l = lengths[i] - end_idx = end_indices[i] - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end -function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} - return NamedTuple{names}(unflatten(values(original), x)) -end -function unflatten(original::AbstractDict, x::AbstractVector) - D = ConstructionBase.constructorof(typeof(original)) - return D(zip(keys(original), unflatten(collect(values(original)), x))) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) @@ -916,50 +640,6 @@ _merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(righ _merge(::NamedTuple{()}, right::AbstractDict) = right _merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) -""" - unique_syms(vns::T) where {T<:NTuple{N,VarName}} - -Return the unique symbols of the variables in `vns`. - -Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike -`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant -propagating the result, which is possible only when the argument and the return value are -`Tuple`s. -""" -@generated function unique_syms(::T) where {T<:VarNameTuple} - retval = Expr(:tuple) - syms = [first(vn.parameters) for vn in T.parameters] - for sym in unique(syms) - push!(retval.args, QuoteNode(sym)) - end - return retval -end - -""" - group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} - -Return a `NamedTuple` of the variables in `vns` grouped by symbol. - -Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to -be type stable. 
- -Example: -```julia -julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) -(x, y[1], x.a, z[15], y[2]) - -julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) -(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) - -julia> group_varnames_by_symbol(vns_tuple) == vns_nt -``` -""" -function group_varnames_by_symbol(vns::VarNameTuple) - syms = unique_syms(vns) - elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) - return NamedTuple{syms}(elements) -end - """ basetypeof(x) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3e026648b..860fb7372 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -406,7 +406,7 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) old_val = tv.val if !(old_val isa AbstractVector) error( - "Can't unflatten a VarInfo for which existing values are not vectors:" * + "Can't unflatten!! a VarInfo for which existing values are not vectors:" * " Got value of type $(typeof(old_val)).", ) end diff --git a/src/varname.jl b/src/varname.jl index 7ffe9cc08..e1492bb32 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,20 +1,3 @@ -""" - subsumes_string(u::String, v::String[, u_indexing]) - -Check whether stringified variable name `v` describes a sub-range of stringified variable `u`. - -This is a very restricted version `subumes(u::VarName, v::VarName)` only really supporting: -- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. - -## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` - for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, - and similarly to `v`. But this is slow. -""" -function subsumes_string(u::String, v::String, u_indexing=u * "[") - return u == v || startswith(v, u_indexing) -end - """ inargnames(varname::VarName, model::Model) diff --git a/test/utils.jl b/test/utils.jl index bef1c2ba8..bc01fc0ce 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -186,29 +186,6 @@ end t = (2.0, [3.0, 4.0]) @test DynamicPPL.tovec(t) == [2.0, 3.0, 4.0] end - - @testset "unique_syms" begin - vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) - @inferred DynamicPPL.unique_syms(vns) - @inferred DynamicPPL.unique_syms(()) - @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) - @test DynamicPPL.unique_syms(()) == () - end - - @testset "group_varnames_by_symbol" begin - vns_tuple = ( - @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) - ) - vns_vec = collect(vns_tuple) - vns_nt = (; - x=[@varname(x), @varname(x.a)], - y=[@varname(y[1]), @varname(y[2])], - z=[@varname(z[15])], - ) - vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] - @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) - @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt - end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 8ae0535c7..639b4f688 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -346,7 +346,7 @@ end end end - @testset "unflatten + linking" begin + @testset "unflatten!! + linking" begin @testset "Model: $(model.f)" for model in [ DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), DynamicPPL.TestUtils.demo_lkjchol(), @@ -403,7 +403,7 @@ end end end - @testset "unflatten type stability" begin + @testset "unflatten!! 
type stability" begin @model function demo(y) x ~ Normal() y ~ Normal(x, 1) From 6128a562203936a5d6da71cf4e0b19d84cbff289 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 13:16:49 +0000 Subject: [PATCH 144/148] Use OnlyAccsVarInfo in extract_priors --- src/extract_priors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 8c7b5f7db..182e933e4 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -121,7 +121,7 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - varinfo = VarInfo() + varinfo = OnlyAccsVarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors From 8fddfef6ce719f8819f2a32b6504bd27a9c00196 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 13:25:37 +0000 Subject: [PATCH 145/148] Make linking status a type parameter of VarInfo --- src/logdensityfunction.jl | 2 +- src/varinfo.jl | 132 +++++++++++++++++++++++++++----------- test/varinfo.jl | 6 +- 3 files changed, 99 insertions(+), 41 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 6ae1dc3a1..9337d159c 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -313,7 +313,7 @@ function get_ranges_and_linked(vi::VarInfo) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv.linked, tv.size) + ral = RangeAndLinked(range, is_transformed(tv), tv.size) vnt = setindex!!(vnt, ral, vn) return vnt, offset end, diff --git a/src/varinfo.jl b/src/varinfo.jl index 860fb7372..a59837ba0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,8 +1,12 @@ """ - VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo The default implementation of `AbstractVarInfo`, storing variable values and accumulators. +The `Linked` type parameter is either `true` or `false` to mark that all variables in this +`VarInfo` are linked, or `nothing` to indicate that some variables may be linked and some +not, and a runtime check is needed. + `VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, and a tuple of accumulators. The only really noteworthy thing about it is that it stores the values of variables vectorised as instances of `TransformedValue`. That is, it stores @@ -27,9 +31,15 @@ For more details on the internal storage, see documentation of [`TransformedValu $(TYPEDFIELDS) """ -struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo +struct VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs + + function VarInfo{Linked}( + values::T, accs::Accs + ) where {Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} + return new{Linked,T,Accs}(values, accs) + end end # TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was @@ -42,31 +52,43 @@ end # during model execution. However, this would change the interface quite a lot, so I want to # finish implementing VarInfo using VNT (mostly) respecting the old interface first. +# TODO(mhauru) We are considering removing `transform` completely, and forcing people to use +# ValuesAsInModelAcc instead. 
If that is done, we may want to move the Linked type parameter +# to just be a bool field. It's currently a type parameter to make the type of `transform` +# easier to type infer, but if `transform` no longer exists, it might start to cause +# unnecessary type inconcreteness in the elements of PartialArray. """ - TransformedValue{ValType,TransformType,SizeType} + TransformedValue{Linked,ValType,TransformType,SizeType} A struct for storing a variable's value in its internal (vectorised) form. +The type parameter `Linked` is a `Bool` indicating whether the variable is linked, i.e. +whether the transformation maps all real vectors to valid values. # Fields $(TYPEDFIELDS) """ -struct TransformedValue{ValType,TransformType,SizeType} +struct TransformedValue{Linked,ValType,TransformType,SizeType} "The internal (vectorised) value." val::ValType - """Boolean indicating whether the variable is linked, i.e. the transformation maps all - real vectors to valid values.""" - linked::Bool """The transformation from internal (vectorised) to actual value. In other words, the actual value of the variable being stored is `transform(val)`.""" transform::TransformType """The size of the actual value after transformation. This is needed when a `TransformedValue` is stored as a block in an array.""" size::SizeType + + function TransformedValue{Linked}( + val::ValType, transform::TransformType, size::SizeType + ) where {Linked,ValType,TransformType,SizeType} + return new{Linked,ValType,TransformType,SizeType}(val, transform, size) + end end +is_transformed(::TransformedValue{Linked}) where {Linked} = Linked + VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size -VarInfo() = VarInfo(VarNamedTuple(), default_accumulators()) +VarInfo() = VarInfo{false}(VarNamedTuple(), default_accumulators()) function VarInfo(values::Union{NamedTuple,AbstractDict}) vi = VarInfo() @@ -90,11 +112,15 @@ function VarInfo( end getaccs(vi::VarInfo) = vi.accs -setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) +function setaccs!!(vi::VarInfo{Linked}, accs::AccumulatorTuple) where {Linked} + return VarInfo{Linked}(vi.values, accs) +end transformation(::VarInfo) = DynamicTransformation() -Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) +function Base.copy(vi::VarInfo{Linked}) where {Linked} + return VarInfo{Linked}(copy(vi.values), copy(getaccs(vi))) +end Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) Base.length(vi::VarInfo) = length(vi.values) Base.keys(vi::VarInfo) = keys(vi.values) @@ -110,8 +136,8 @@ function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) end Base.isempty(vi::VarInfo) = isempty(vi.values) -Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) -BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) +Base.empty(vi::VarInfo) = VarInfo{false}(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VarInfo) = VarInfo{false}(empty!!(vi.values), map(reset, vi.accs)) """ setindex_internal!!(vi::VarInfo, val, vn::VarName) @@ -120,11 +146,11 @@ Set the internal (vectorised) value of variable `vn` in `vi` to `val`. This does not change the transformation or linked status of the variable. 
""" -function setindex_internal!!(vi::VarInfo, val, vn::VarName) +function setindex_internal!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) + new_tv = TransformedValue{is_transformed(old_tv)}(val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VarInfo(new_values, vi.accs) + return VarInfo{Linked}(new_values, vi.accs) end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, @@ -147,8 +173,14 @@ used. Returns the modified `vi` together with the log absolute determinant of the Jacobian of the transformation applied. """ -function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) - link = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) +function setindex_with_dist!!( + vi::VarInfo{Linked}, val, dist::Distribution, vn::VarName +) where {Linked} + link = if Linked === nothing + haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) + else + Linked + end transform = if link from_linked_vec_transform(dist) else @@ -157,8 +189,9 @@ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) # All values for which `size` is not defined are assumed to be scalars. val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - tv = TransformedValue(transformed_val, link, transform, val_size) - vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) + tv = TransformedValue{link}(transformed_val, transform, val_size) + new_linked = Linked == link ? Linked : nothing + vi = VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end @@ -174,11 +207,12 @@ Set the value of `vn` in `vi` to `val`. The transformation for `vn` is reset to be the standard vector transformation for values of the type of `val` and linking status is set to false. """ -function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) +function BangBang.setindex!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} + new_linked = Linked == false ? false : nothing transform = from_vec_transform(val) transformed_val = inverse(transform)(val) - tv = TransformedValue(transformed_val, false, transform, size(val)) - return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) + tv = TransformedValue{false}(transformed_val, transform, size(val)) + return VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs) end """ @@ -188,11 +222,14 @@ Set the linked status of variable `vn` in `vi` to `linked`. This does not change the value or transformation of the variable. """ -function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) +function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) + new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VarInfo(new_values, vi.accs) + # The below check shouldn't ever pass, this should always result in `nothing`, but may + # as well play it safe, it'll be constant propagated away anyway. + new_linked = Linked == linked ? 
Linked : nothing + return VarInfo{new_linked}(new_values, vi.accs) end # VarInfo does not care whether the transformation was Static or Dynamic, it just tracks @@ -211,9 +248,9 @@ set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false function set_transformed!!(vi::VarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv - TransformedValue(tv.val, linked, tv.transform, tv.size) + TransformedValue{linked}(tv.val, tv.transform, tv.size) end - return VarInfo(new_values, vi.accs) + return VarInfo{linked}(new_values, vi.accs) end """ @@ -225,7 +262,13 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val # TODO(mhauru) The below should be removed together with unflatten!!. getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -is_transformed(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).linked +function is_transformed(vi::VarInfo{Linked}, vn::VarName) where {Linked} + return if Linked === nothing + is_transformed(getindex(vi.values, vn)) + else + Linked + end +end function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) @@ -273,7 +316,7 @@ function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where { # Not one of the target variables. return tv end - if tv.linked == link + if is_transformed(tv) == link # Already in the desired state. return tv end @@ -289,11 +332,17 @@ function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where { val_new, logjac2 = with_logabsdet_jacobian( inverse(new_transform), val_untransformed ) - new_tv = TransformedValue(val_new, link, new_transform, tv.size) + # !is_transformed(tv) is the same as `link`, but might be easier for type inference. + new_tv = TransformedValue{!is_transformed(tv)}(val_new, new_transform, tv.size) cumulative_logjac += logjac1 + logjac2 return new_tv end - vi = VarInfo(new_values, vi.accs) + vi_linked = if vns === nothing + link + else + nothing + end + vi = VarInfo{vi_linked}(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end @@ -397,7 +446,7 @@ function get_next_chunk!(vci::VectorChunkIterator, len::Int) return chunk end -function unflatten!!(vi::VarInfo, vec::AbstractVector) +function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked} # You may wonder, why have a whole struct for this, rather than just an index variable # that the mapping function would close over. I wonder too. But for some reason type # inference fails on such an index variable, turning it into a Core.Box. @@ -412,9 +461,9 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) end len = length(old_val) new_val = get_next_chunk!(vci, len) - return TransformedValue(new_val, tv.linked, tv.transform, tv.size) + return TransformedValue{is_transformed(tv)}(new_val, tv.transform, tv.size) end - return VarInfo(new_values, vi.accs) + return VarInfo{Linked}(new_values, vi.accs) end """ @@ -424,9 +473,9 @@ Create a new `VarInfo` containing only the variables in `vns`. `vns` can be almost any collection of `VarName`s, e.g. a `Set`, `Vector`, or `Tuple`. """ -function subset(varinfo::VarInfo, vns) +function subset(varinfo::VarInfo{Linked}, vns) where {Linked} new_values = subset(varinfo.values, vns) - return VarInfo(new_values, map(copy, getaccs(varinfo))) + return VarInfo{Linked}(new_values, map(copy, getaccs(varinfo))) end """ @@ -439,8 +488,15 @@ The accumulators are taken exclusively from `varinfo_right`. 
If a variable exists in both `varinfo_left` and `varinfo_right`, the value from `varinfo_right` is used. """ -function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) +function Base.merge( + varinfo_left::VarInfo{LinkedLeft}, varinfo_right::VarInfo{LinkedRight} +) where {LinkedLeft,LinkedRight} new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) - return VarInfo(new_values, new_accs) + new_linked = if LinkedLeft == LinkedRight + LinkedLeft + else + nothing + end + return VarInfo{new_linked}(new_values, new_accs) end diff --git a/test/varinfo.jl b/test/varinfo.jl index 639b4f688..9fb8c6d4d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -265,8 +265,10 @@ end _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) vals = values(vi) - all_transformed(vi) = mapreduce(p -> p.second.linked, &, vi.values; init=true) - any_transformed(vi) = mapreduce(p -> p.second.linked, |, vi.values; init=false) + all_transformed(vi) = + mapreduce(p -> is_transformed(p.second), &, vi.values; init=true) + any_transformed(vi) = + mapreduce(p -> is_transformed(p.second), |, vi.values; init=false) @test !any_transformed(vi) From aa3adb327fd2313f97dffd4df2e20a3f207cfcdc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 14:47:39 +0000 Subject: [PATCH 146/148] Fix a typo Co-authored-by: Penelope Yong --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index d6c13f7a6..c3a704552 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -14,7 +14,7 @@ Almost all of the below changes are the consequence from switching over to using ### Overhaul of `VarInfo` DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types. -Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped, and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. +Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped", and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. These have all been replaced by a rewritten implementation of `VarInfo`. While the basics of the `VarInfo` interface remain the same, this brings with it many changes: From 0c03233daa9d4b09345087ba1786d60da57114af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 14:55:03 +0000 Subject: [PATCH 147/148] Simplify code Co-authored-by: Penelope Yong --- src/extract_priors.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 182e933e4..def2b7756 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -121,8 +121,7 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) 
function extract_priors(rng::Random.AbstractRNG, model::Model) - varinfo = OnlyAccsVarInfo() - varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) + varinfo = OnlyAccsVarInfo((PriorDistributionAccumulator(),)) varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end From 39df57b586e9b58d7b954bb78d758335ece312d3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 15:01:20 +0000 Subject: [PATCH 148/148] Fix comments, remove dead line --- src/logdensityfunction.jl | 2 -- src/varinfo.jl | 6 ++++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9337d159c..7cb84cbc2 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -304,8 +304,6 @@ This function returns a VarNamedTuple mapping all VarNames to their correspondin `RangeAndLinked`. """ function get_ranges_and_linked(vi::VarInfo) - # TODO(mhauru) Check that the closure doesn't cause type instability here. - vnt = VarNamedTuple() vnt, _ = mapreduce( identity, function ((vnt, offset), pair) diff --git a/src/varinfo.jl b/src/varinfo.jl index a59837ba0..191537ad8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -226,8 +226,6 @@ function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where old_tv = getindex(vi.values, vn) new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - # The below check shouldn't ever pass, this should always result in `nothing`, but may - # as well play it safe, it'll be constant propagated away anyway. new_linked = Linked == linked ? Linked : nothing return VarInfo{new_linked}(new_values, vi.accs) end @@ -496,6 +494,10 @@ function Base.merge( new_linked = if LinkedLeft == LinkedRight LinkedLeft else + # TODO(mhauru) Consider doing something more clever here, e.g. checking whether + # either varinfo_left or varinfo_right is empty, or actually iterating over all the + # values to check their linked status. Needs to balance keeping the type parameter + # alive vs runtime costs. nothing end return VarInfo{new_linked}(new_values, new_accs)