From 35c3e20b1247493f74b38f582c170a4ab4957dca Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Dec 2025 19:00:30 +0000 Subject: [PATCH 01/26] ArrayLikeBlock WIP --- src/varnamedtuple.jl | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index db2462e53..0e984b7a7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -55,6 +55,11 @@ const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" const INDEX_TYPES = Union{Integer,UnitRange,Colon} +struct ArrayLikeBlock{T,I} + block::T + inds::I +end + """ PartialArray{ElType,numdims} @@ -105,6 +110,9 @@ means that the largest index set so far determines the memory usage of the `Part a few scattered values are set, a structure like `SparseArray` may be more appropriate. """ struct PartialArray{ElType,num_dims} + # TODO(mhauru) Consider trying FixedSizeArrays instead, see how it would change + # performance. We reallocate new Arrays every time when resizing anyway, except for + # Vectors, which can be extended in place. data::Array{ElType,num_dims} mask::Array{Bool,num_dims} @@ -395,7 +403,34 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else _resize_partialarray!!(pa, inds) end - new_data = setindex!!(pa.data, value, inds...) + + new_data = pa.data + if _is_multiindex(inds) && !(isa(value, AbstractArray)) + if !hasmethod(size, value) + throw(ArgumentError("Cannot assign a scalar value to a range.")) + end + if size(value) != map(x -> _length_needed(x), inds) + throw( + DimensionMismatch( + "Assigned value has size $(size(value)), which does not match the size " * + "implied by the indices $(map(x -> _length_needed(x), inds)).", + ), + ) + end + # At this point we know we have a value that is not an AbstractArray, but it has + # some notion of size, and that size matches the indices that are being set. 
In this + # case we wrap the value in a ArrayLikeBlock, and set all the individual indices + # point to that, with the right subindices. + first_index = first.(inds) + # Iterate over all the subindices of inds. + for ind in CartesianIndices(map(x -> _length_needed(x), inds)) + subinds = ntuple(i -> first_index[i] + ind[i] - 1, length(inds)) + new_data = _setindex!!(new_data, ArrayLikeBlock(value, Tuple(ind)), subinds...) + end + else + new_data = setindex!!(new_data, value, inds...) + end + if _is_multiindex(inds) pa.mask[inds...] .= true else From 4253e9b53c89aabe80f23fde3c4dfc059bf20d11 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 18:17:23 +0000 Subject: [PATCH 02/26] ArrayLikeBlock WIP2 --- src/varnamedtuple.jl | 118 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 50 ++++++++++++++++++ 2 files changed, 155 insertions(+), 13 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0e984b7a7..fa711e4f4 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -58,6 +58,13 @@ const INDEX_TYPES = Union{Integer,UnitRange,Colon} struct ArrayLikeBlock{T,I} block::T inds::I + + function ArrayLikeBlock(block::T, inds::I) where {T,I} + if !_is_multiindex(inds) + throw(ArgumentError("ArrayLikeBlock must be constructed with a multi-index")) + end + return new{T,I}(block, inds) + end end """ @@ -385,15 +392,102 @@ end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) _check_index_validity(pa, inds) - if !_haskey(pa, inds) + if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))) throw(BoundsError(pa, inds)) end - return getindex(pa.data, inds...) + val = getindex(pa.data, inds...) + + # If not for ArrayLikeBlocks, at this point we could just return val directly. However, + # we need to check if val contains any ArrayLikeBlocks, and if so, make sure that + # we are retrieving exactly that block and nothing else. 
+ + # The error we'll throw if the retrieval is invalid. + err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + if val isa ArrayLikeBlock + # Tried to get a single value, but it's an ArrayLikeBlock. + throw(err) + elseif val isa Array && (eltype(val) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(val)) + # Tried to get a range of values, and at least some of them may be ArrayLikeBlocks. + # The below isempty check is deliberately kept separate from the outer elseif, + # because the outer one can be resolved at compile time. + if isempty(val) + return val + end + first_elem = first(val) + if !(first_elem isa ArrayLikeBlock) + throw(err) + end + if inds != first_elem.inds + # The requested indices do not match the ones used to set the value. + throw(err) + end + # If _setindex!! works correctly, we should only be able to reach this point if all + # the elements in `val` are identical to first_elem. Thus we just return that one. + return first(val).block + else + return val + end end function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} _check_index_validity(pa, inds) - return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + hasall = + checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + + # If not for ArrayLikeBlocks, we could just return hasall directly. However, we need to + # check that if any ArrayLikeBlocks are included, they are fully included. + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return hasall + end + + if !hasall + return false + end + # From this point on we can assume that all the requested elements are set, and the only + # thing to check is that we are not partially indexing into any ArrayLikeBlocks. 
+ # We've already checked checkbounds at the top of the function, and returned if it + # wasn't true, so @inbounds is safe. + subdata = @inbounds getindex(pa.data, inds...) + if !_is_multiindex(inds) + return !(subdata isa ArrayLikeBlock) + end + return !any(elem -> elem isa ArrayLikeBlock && elem.inds != inds, subdata) +end + +function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if _is_multiindex(inds) + pa.mask[inds...] .= false + else + pa.mask[inds...] = false + end + return _concretise_eltype!!(pa) +end + +_ensure_range(r::UnitRange) = r +_ensure_range(i::Integer) = i:i + +function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return pa + end + + for i in CartesianIndices(map(_ensure_range, inds)) + if pa.mask[i] + val = @inbounds pa.data[i] + if val isa ArrayLikeBlock + pa = delete!!(pa, val.inds...) + end + end + end + return pa end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) @@ -403,13 +497,15 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else _resize_partialarray!!(pa, inds) end + pa = _remove_partial_blocks!!(pa, inds...) 
new_data = pa.data if _is_multiindex(inds) && !(isa(value, AbstractArray)) - if !hasmethod(size, value) + if !hasmethod(size, Tuple{typeof(value)}) throw(ArgumentError("Cannot assign a scalar value to a range.")) end - if size(value) != map(x -> _length_needed(x), inds) + inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) + if size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(size(value)), which does not match the size " * @@ -419,14 +515,10 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) end # At this point we know we have a value that is not an AbstractArray, but it has # some notion of size, and that size matches the indices that are being set. In this - # case we wrap the value in a ArrayLikeBlock, and set all the individual indices - # point to that, with the right subindices. - first_index = first.(inds) - # Iterate over all the subindices of inds. - for ind in CartesianIndices(map(x -> _length_needed(x), inds)) - subinds = ntuple(i -> first_index[i] + ind[i] - 1, length(inds)) - new_data = _setindex!!(new_data, ArrayLikeBlock(value, Tuple(ind)), subinds...) - end + # case we wrap the value in an ArrayLikeBlock, and set all the individual indices + # point to that. + alb = ArrayLikeBlock(value, inds) + new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) else new_data = setindex!!(new_data, value, inds...) 
end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 3beadebf8..9365d5a7b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,6 +2,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset +using Distributions: Dirichlet using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray using AbstractPPL: VarName, prefix @@ -458,6 +459,55 @@ end VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0),),)),))""" end + + @testset "block variables" begin + # Tests for setting and getting block variables, i.e. variables that have a non-zero + # size in a PartialArray, but are not Arrays themselves. + expected_err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + vnt = VarNamedTuple() + vnt = setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4])) + @test haskey(vnt, @varname(x[2:4])) + @test getindex(vnt, @varname(x[2:4])) == Dirichlet(3, 1.0) + @test !haskey(vnt, @varname(x[2:3])) + @test_throws expected_err getindex(vnt, @varname(x[2:3])) + @test !haskey(vnt, @varname(x[3])) + @test_throws expected_err getindex(vnt, @varname(x[3])) + @test !haskey(vnt, @varname(x[1])) + @test !haskey(vnt, @varname(x[5])) + vnt = setindex!!(vnt, 1.0, @varname(x[1])) + vnt = setindex!!(vnt, 1.0, @varname(x[5])) + @test haskey(vnt, @varname(x[1])) + @test haskey(vnt, @varname(x[5])) + @test_throws expected_err getindex(vnt, @varname(x[1:4])) + @test_throws expected_err getindex(vnt, @varname(x[2:5])) + + # Setting any of these indices should remove the block variable x[2:4]. + @testset "index = $index" for index in (2, 3, 4, 2:3, 3:5) + # Test setting different types of values. 
+ vals = if index isa Int + (2.0,) + else + (fill(2.0, length(index)), Dirichlet(length(index), 2.0)) + end + @testset "val = $val" for val in vals + vn = @varname(x[index]) + vnt2 = copy(vnt) + vnt2 = setindex!!(vnt2, val, vn) + @test !haskey(vnt2, @varname(x[2:4])) + @test_throws BoundsError getindex(vnt2, @varname(x[2:4])) + other_index = index in (2, 2:3) ? 4 : 2 + @test !haskey(vnt2, @varname(x[other_index])) + @test_throws BoundsError getindex(vnt2, @varname(x[other_index])) + @test haskey(vnt2, vn) + @test getindex(vnt2, vn) == val + @test haskey(vnt2, @varname(x[1])) + @test_throws BoundsError getindex(vnt2, @varname(x[1:4])) + end + end + end end end From 5cb3916ddf1898b36a54e590e88ced043ee18765 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 18:50:24 +0000 Subject: [PATCH 03/26] Improve type stability of ArrayLikeBlock stuff --- src/varnamedtuple.jl | 30 +++++++++++++++++++++++++----- test/varnamedtuple.jl | 28 ++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index fa711e4f4..3210f474c 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -67,6 +67,8 @@ struct ArrayLikeBlock{T,I} end end +_blocktype(::Type{ArrayLikeBlock{T}}) where {T} = T + """ PartialArray{ElType,numdims} @@ -414,7 +416,13 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) # The below isempty check is deliberately kept separate from the outer elseif, # because the outer one can be resolved at compile time. if isempty(val) - return val + # We need to return an empty array, but for type stability, we want to unwrap + # any ArrayLikeBlock types in the element type. 
+ return if eltype(val) <: ArrayLikeBlock + Array{_blocktype(eltype(val)),ndims(val)}() + else + val + end end first_elem = first(val) if !(first_elem isa ArrayLikeBlock) @@ -490,6 +498,12 @@ function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) return pa end +function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) + return _is_multiindex(inds) && + !isa(value, AbstractArray) && + hasmethod(size, Tuple{typeof(value)}) +end + function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) _check_index_validity(pa, inds) pa = if checkbounds(Bool, pa.mask, inds...) @@ -500,7 +514,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) pa = _remove_partial_blocks!!(pa, inds...) new_data = pa.data - if _is_multiindex(inds) && !(isa(value, AbstractArray)) + if _needs_arraylikeblock(value, inds...) if !hasmethod(size, Tuple{typeof(value)}) throw(ArgumentError("Cannot assign a scalar value to a range.")) end @@ -843,9 +857,15 @@ end function make_leaf(value, optic::IndexLens) inds = optic.indices num_inds = length(inds) - # Check if any of the indices are ranges or colons. If yes, value needs to be an - # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(inds) ? eltype(value) : typeof(value) + # The element type of the PartialArray depends on whether we are setting a single value + # or a range of values. + et = if !_is_multiindex(inds) + typeof(value) + elseif _needs_arraylikeblock(value, inds...) + ArrayLikeBlock{typeof(value),typeof(inds)} + else + eltype(value) + end pa = PartialArray{et,num_inds}() return _setindex!!(pa, value, optic) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 9365d5a7b..fe66ab317 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -468,9 +468,9 @@ end range of indices. 
""") vnt = VarNamedTuple() - vnt = setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4])) + vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) @test haskey(vnt, @varname(x[2:4])) - @test getindex(vnt, @varname(x[2:4])) == Dirichlet(3, 1.0) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) @test !haskey(vnt, @varname(x[2:3])) @test_throws expected_err getindex(vnt, @varname(x[2:3])) @test !haskey(vnt, @varname(x[3])) @@ -507,6 +507,30 @@ end @test_throws BoundsError getindex(vnt2, @varname(x[1:4])) end end + + # Extra checks, mostly for type stability and to confirm that multidimensional + # blocks work too. + struct TwoByTwoBlock end + Base.size(::TwoByTwoBlock) = (2, 2) + val = TwoByTwoBlock() + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + @test haskey(vnt, @varname(y.z[1:2, 1:2])) + @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val + @test !haskey(vnt, @varname(y.z[1, 1])) + @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test !haskey(vnt, @varname(y.z[1:2, 1:2])) + @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test haskey(vnt, @varname(y.z[4:5, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val end end From a96bb44a6cc5ca48dd293a8ba708f41cdce58300 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 09:07:39 +0000 Subject: [PATCH 04/26] Test more invariants --- test/varnamedtuple.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index fe66ab317..2f113aacc 100644 --- a/test/varnamedtuple.jl +++ 
b/test/varnamedtuple.jl @@ -469,6 +469,7 @@ end """) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) + test_invariants(vnt) @test haskey(vnt, @varname(x[2:4])) @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) @test !haskey(vnt, @varname(x[2:3])) @@ -479,6 +480,7 @@ end @test !haskey(vnt, @varname(x[5])) vnt = setindex!!(vnt, 1.0, @varname(x[1])) vnt = setindex!!(vnt, 1.0, @varname(x[5])) + test_invariants(vnt) @test haskey(vnt, @varname(x[1])) @test haskey(vnt, @varname(x[5])) @test_throws expected_err getindex(vnt, @varname(x[1:4])) @@ -496,6 +498,7 @@ end vn = @varname(x[index]) vnt2 = copy(vnt) vnt2 = setindex!!(vnt2, val, vn) + test_invariants(vnt) @test !haskey(vnt2, @varname(x[2:4])) @test_throws BoundsError getindex(vnt2, @varname(x[2:4])) other_index = index in (2, 2:3) ? 4 : 2 @@ -515,18 +518,21 @@ end val = TwoByTwoBlock() vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[1:2, 1:2])) @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val @test !haskey(vnt, @varname(y.z[1, 1])) @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[2:3, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test !haskey(vnt, @varname(y.z[1:2, 1:2])) @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + test_invariants(vnt) @test haskey(vnt, @varname(y.z[2:3, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test haskey(vnt, @varname(y.z[4:5, 2:3])) From a8014e6208f2cd346d400329fc2a71ae73f017c9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 10:21:21 +0000 Subject: [PATCH 05/26] Actually run VNT tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff 
--git a/test/runtests.jl b/test/runtests.jl index 9649aebbb..e0b42904c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ include("test_util.jl") include("utils.jl") include("accumulators.jl") include("compiler.jl") + include("varnamedtuple.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") From cfc60419fde7736ec4d051b7d99d61dbb5fb5a61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:13:05 +0000 Subject: [PATCH 06/26] Implement show for ArrayLikeBlock --- src/varnamedtuple.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 3210f474c..5dfbf153e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -67,6 +67,18 @@ struct ArrayLikeBlock{T,I} end end +function Base.show(io::IO, alb::ArrayLikeBlock) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + print(io, "ArrayLikeBlock(") + show(io, alb.block) + print(io, ", ") + show(io, alb.inds) + print(io, ")") + return nothing +end + _blocktype(::Type{ArrayLikeBlock{T}}) where {T} = T """ From e198fbb1c4234c0b4b69fe493751f7cba0b54a30 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:14:16 +0000 Subject: [PATCH 07/26] Change keys on VNT to return an array --- src/varnamedtuple.jl | 15 ++++----------- test/varnamedtuple.jl | 28 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 5dfbf153e..62c36d021 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -792,25 +792,18 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end -# TODO(mhauru) Should this return tuples, like it does now? 
That makes sense for -# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big. -# Also, this is not very type stable, it fails even in basic cases. A generated function -# would help, but I failed to make one. Might be something to do with a recursive -# generated function. function Base.keys(vnt::VarNamedTuple) - result = () + result = VarName[] for sym in keys(vnt.data) subdata = vnt.data[sym] if subdata isa VarNamedTuple subkeys = keys(subdata) - result = ( - result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... - ) + append!(result, [AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys]) elseif subdata isa PartialArray subkeys = keys(subdata) - result = (result..., (VarName{sym}(lens) for lens in subkeys)...) + append!(result, [VarName{sym}(lens) for lens in subkeys]) else - result = (result..., VarName{sym}()) + push!(result, VarName{sym}()) end end return result diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 2f113aacc..41bcd5fd5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -343,36 +343,36 @@ end @testset "keys" begin vnt = VarNamedTuple() - @test @inferred(keys(vnt)) == () + @test @inferred(keys(vnt)) == VarName[] vnt = setindex!!(vnt, 1.0, @varname(a)) # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. # We should improve type stability of keys(). 
- @test @inferred(keys(vnt)) == (@varname(a),) + @test @inferred(keys(vnt)) == [@varname(a)] vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, 15, @varname(b[2])) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, [10], @varname(c.x.y)) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)] vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])] vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), - ) + ] vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -382,10 +382,10 @@ end @varname(j[2]), @varname(j[3]), @varname(j[4]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(j[6])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -396,10 +396,10 @@ end @varname(j[3]), @varname(j[4]), @varname(j[6]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -411,7 +411,7 @@ end @varname(j[4]), @varname(j[6]), @varname(n[2].a), - ) + ] end @testset "printing" begin From b77b0af1d64ee60f5703de165e6f187d707f7550 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:15:54 +0000 Subject: [PATCH 08/26] Fix keys and some tests for PartialArray --- src/varnamedtuple.jl | 2 ++ test/varnamedtuple.jl | 25 +++++++++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git 
a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 62c36d021..711fc6037 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -639,6 +639,8 @@ function Base.keys(pa::PartialArray) sublens = _varname_to_lens(vn) push!(ks, _compose_no_identity(sublens, lens)) end + elseif val isa ArrayLikeBlock + push!(ks, IndexLens(Tuple(val.inds))) else push!(ks, lens) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 41bcd5fd5..67c3621a7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,9 +2,8 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset -using Distributions: Dirichlet using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray +using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock using AbstractPPL: VarName, prefix using BangBang: setindex!! @@ -20,12 +19,18 @@ function test_invariants(vnt::VarNamedTuple) for k in keys(vnt) @test haskey(vnt, k) v = getindex(vnt, k) + # ArrayLikeBlocks are an implementation detail, and should not be exposed through + # getindex. + @test !(v isa ArrayLikeBlock) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end # Check that the printed representation can be parsed back to an equal VarNamedTuple. + # The below eval test is a bit fragile: If any elements in vnt don't respect the same + # reconstructability-from-repr property, this will fail. Likewise if any element uses + # in its repr print out types that are not in scope in this module, it will fail. vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 @test isequal(vnt, vnt3) @@ -461,6 +466,12 @@ end end @testset "block variables" begin + """ A type that has a size but is not an Array.""" + struct SizedThing + size::Tuple + end + Base.size(st::SizedThing) = st.size + # Tests for setting and getting block variables, i.e. 
variables that have a non-zero # size in a PartialArray, but are not Arrays themselves. expected_err = ArgumentError(""" @@ -468,10 +479,10 @@ end range of indices. """) vnt = VarNamedTuple() - vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))) + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), @varname(x[2:4]))) test_invariants(vnt) @test haskey(vnt, @varname(x[2:4])) - @test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == SizedThing((3,)) @test !haskey(vnt, @varname(x[2:3])) @test_throws expected_err getindex(vnt, @varname(x[2:3])) @test !haskey(vnt, @varname(x[3])) @@ -492,7 +503,7 @@ end vals = if index isa Int (2.0,) else - (fill(2.0, length(index)), Dirichlet(length(index), 2.0)) + (fill(2.0, length(index)), SizedThing((length(index),))) end @testset "val = $val" for val in vals vn = @varname(x[index]) @@ -513,9 +524,7 @@ end # Extra checks, mostly for type stability and to confirm that multidimensional # blocks work too. - struct TwoByTwoBlock end - Base.size(::TwoByTwoBlock) = (2, 2) - val = TwoByTwoBlock() + val = SizedThing((2, 2)) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) test_invariants(vnt) From 633e920c561bb536a4b0c633c70f10a5ee5952a3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:39:09 +0000 Subject: [PATCH 09/26] Improve type stability --- src/varnamedtuple.jl | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 711fc6037..36ddd7377 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -27,9 +27,30 @@ function _setindex!!(arr::AbstractArray, value, optic::IndexLens) end # Some utilities for checking what sort of indices we are dealing with. 
-_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) -function _is_multiindex(::T) where {T<:Tuple} - return any(x <: UnitRange || x <: Colon for x in T.parameters) +# The non-generated function implementations of these would be +# _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) +# function _is_multiindex(::T) where {T<:Tuple} +# return any(x <: UnitRange || x <: Colon for x in T.parameters) +# end +# However, constant propagation sometimes fails if the index tuple is too big (e.g. length +# 4), so we play it safe and use generated functions. Constant propagating these is +# important, because many functions choose different paths based on their values, which +# would lead to type instability if they were only evaluated at runtime. +@generated function _has_colon(::T) where {T<:Tuple} + for x in T.parameters + if x <: Colon + return :(true) + end + end + return :(false) +end +@generated function _is_multiindex(::T) where {T<:Tuple} + for x in T.parameters + if x <: UnitRange || x <: Colon + return :(true) + end + end + return :(false) end """ From 222334a97219d6786cbac319285ebe2441d12230 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:39:56 +0000 Subject: [PATCH 10/26] Fix keys for PartialArray --- src/varnamedtuple.jl | 6 +++++- test/varnamedtuple.jl | 28 ++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 36ddd7377..ffc69bdcf 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -652,6 +652,7 @@ function Base.keys(pa::PartialArray) inds = findall(pa.mask) lenses = map(x -> IndexLens(Tuple(x)), inds) ks = Any[] + alb_inds_seen = Set{Tuple}() for lens in lenses val = getindex(pa.data, lens.indices...) 
if val isa VarNamedTuple @@ -661,7 +662,10 @@ function Base.keys(pa::PartialArray) push!(ks, _compose_no_identity(sublens, lens)) end elseif val isa ArrayLikeBlock - push!(ks, IndexLens(Tuple(val.inds))) + if !(val.inds in alb_inds_seen) + push!(ks, IndexLens(Tuple(val.inds))) + push!(alb_inds_seen, val.inds) + end else push!(ks, lens) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 67c3621a7..c9e9cb07f 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -40,6 +40,12 @@ function test_invariants(vnt::VarNamedTuple) @test merge(VarNamedTuple(), vnt) == vnt end +""" A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" +struct SizedThing{T<:Tuple} + size::T +end +Base.size(st::SizedThing) = st.size + @testset "VarNamedTuple" begin @testset "Construction" begin vnt1 = VarNamedTuple() @@ -417,6 +423,22 @@ end @varname(j[6]), @varname(n[2].a), ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + ] end @testset "printing" begin @@ -466,12 +488,6 @@ end end @testset "block variables" begin - """ A type that has a size but is not an Array.""" - struct SizedThing - size::Tuple - end - Base.size(st::SizedThing) = st.size - # Tests for setting and getting block variables, i.e. variables that have a non-zero # size in a PartialArray, but are not Arrays themselves. 
expected_err = ArgumentError(""" From d22face74a1d70bab93feeba2a1739d1b823a817 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:40:12 +0000 Subject: [PATCH 11/26] More ArrayLikeBlock tests --- test/varnamedtuple.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index c9e9cb07f..8be72a184 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -562,6 +562,30 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val @test haskey(vnt, @varname(y.z[4:5, 2:3])) @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val + + # A lot like above, but with extra indices that are not ranges. + val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 1:2, 3, 1:2, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1, 3, 1, 4])) + @test_throws expected_err getindex(vnt, @varname(y.z[2, 1, 3, 1, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test_throws BoundsError getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[3, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end end From 4cb49e1194616ce8238d583c8aba22d18c1ab49f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 11:46:02 +0000 Subject: 
[PATCH 12/26] Add docstrings --- src/varnamedtuple.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index ffc69bdcf..ab66da5ac 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -76,6 +76,19 @@ const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" const INDEX_TYPES = Union{Integer,UnitRange,Colon} +""" + ArrayLikeBlock{T,I} + +A wrapper for non-array blocks stored in `PartialArray`s. + +When setting a value in a `PartialArray` over a range of indices, if the value being set +is not itself an `AbstractArray`, but has a well-defined size, we wrap it in an +`ArrayLikeBlock`, which records both the value and the indices it was set with. + +When getting values from a `PartialArray`, if any of the requested indices correspond to +an `ArrayLikeBlock`, we check that the requested indices match the ones used to set the +value. If they do, we return the underlying block, otherwise we throw an error. +""" struct ArrayLikeBlock{T,I} block::T inds::I @@ -136,6 +149,14 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem `ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly `numdims` elements. +One can set values in a `PartialArray` either element-by-element, or with ranges like +`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set +must either be an `AbstractArray` or otherwise something for which `size(value)` is defined, +and the size matches the range. If the value is an `AbstractArray`, the elements are copied +individually, but if it is not, the value is stored as a block, that takes up the whole +range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done +with the exact same range of indices, otherwise an error is thrown. 
+ If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check if, after the new value has been set, the element type can be made more concrete. If so, a new `PartialArray` with a more concrete element type is returned. Thus the element type From 420a6b2889429625a7fadb948937f2ef1ce1b6aa Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 12:03:39 +0000 Subject: [PATCH 13/26] Remove redundant code, improve documentation --- src/varnamedtuple.jl | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index ab66da5ac..308951608 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -528,12 +528,20 @@ function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) else pa.mask[inds...] = false end - return _concretise_eltype!!(pa) + return pa end _ensure_range(r::UnitRange) = r _ensure_range(i::Integer) = i:i +""" + _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + +Remove any ArrayLikeBlocks that overlap with the given indices from the PartialArray. + +Note that this removes the whole block, even the parts that are within `inds`, to avoid +partially indexing into ArrayLikeBlocks. +""" function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) et = eltype(pa) if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) @@ -552,6 +560,13 @@ function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) return pa end +""" + _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) + +Check if the given value needs to be wrapped in an `ArrayLikeBlock` when being set at inds. + +The value only depends on the types of the arguments, and should be constant propagated. 
+""" function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) return _is_multiindex(inds) && !isa(value, AbstractArray) && @@ -569,9 +584,6 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) - if !hasmethod(size, Tuple{typeof(value)}) - throw(ArgumentError("Cannot assign a scalar value to a range.")) - end inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) if size(value) != inds_size throw( @@ -584,7 +596,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) # At this point we know we have a value that is not an AbstractArray, but it has # some notion of size, and that size matches the indices that are being set. In this # case we wrap the value in an ArrayLikeBlock, and set all the individual indices - # point to that. + # to point to that. alb = ArrayLikeBlock(value, inds) new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) else From ce9da19422b4255c97bcb42c6e3a4a9eff29e31d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 12:10:11 +0000 Subject: [PATCH 14/26] Add Base.size(::RangeAndLinked) --- src/contexts/init.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 90394a24c..e666e0622 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -213,6 +213,8 @@ struct RangeAndLinked is_linked::Bool end +Base.size(ral::RangeAndLinked) = size(ral.range) + """ VectorWithRanges{Tlink}( varname_ranges::VarNamedTuple, From 4eb33e931853d55400cf6b897bf97f485d6016bd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 16:36:22 +0000 Subject: [PATCH 15/26] Fix issues with RangeAndLinked and VNT --- ext/DynamicPPLMarginalLogDensitiesExt.jl | 10 ++++------ src/contexts/init.jl | 13 +++++++++--- src/logdensityfunction.jl | 10 ++++++++-- src/varname.jl | 25 ++++++++++++++++++++++++ src/varnamedtuple.jl | 21 ++++++++++++++++---- 5 files changed, 64 
insertions(+), 15 deletions(-) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index ffb5baf25..e28560872 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -105,11 +105,9 @@ function DynamicPPL.marginalize( ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn - if DynamicPPL.getoptic(vn) === identity - ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range - else - ldf._varname_ranges[vn].range - end + # The type assertion helps in cases where the model is type unstable and thus + # `varname_ranges` may have an abstract element type. + (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( LogDensityFunctionWrapper(ldf, varinfo), diff --git a/src/contexts/init.jl b/src/contexts/init.jl index e666e0622..dc811df85 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -206,14 +206,16 @@ an unlinked value. 
$(TYPEDFIELDS) """ -struct RangeAndLinked +struct RangeAndLinked{T<:Tuple} # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} # whether it's linked is_linked::Bool + # original size of the variable before vectorisation + original_size::T end -Base.size(ral::RangeAndLinked) = size(ral.range) +Base.size(ral::RangeAndLinked) = ral.original_size """ VectorWithRanges{Tlink}( @@ -249,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) - return vr.varname_ranges[vn] + # The type assertion does nothing if VectorWithRanges has concrete element types, as is + # the case for all type stable models. However, if the model is not type stable, + # vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to + # assert that it is a RangeAndLinked, because even though it remains non-concrete, + # it'll allow the compiler to infer the types of `range` and `is_linked`. 
+ return vr.varname_ranges[vn]::RangeAndLinked end function init( ::Random.AbstractRNG, diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 47b49a277..89e2b5989 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) + orig_size = varnamesize(vn) + all_ranges = BangBang.setindex!!( + all_ranges, RangeAndLinked(range, is_linked, orig_size), vn + ) offset += length(range) end return all_ranges, offset @@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) + orig_size = varnamesize(vn) + all_ranges = BangBang.setindex!!( + all_ranges, RangeAndLinked(range, is_linked, orig_size), vn + ) offset += length(range) end return all_ranges, offset diff --git a/src/varname.jl b/src/varname.jl index 3eb1f2460..7ffe9cc08 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL. +""" + varnamesize(vn::VarName) + +Return the size of the object referenced by this VarName. 
+ +```jldoctest +julia> varnamesize(@varname(a)) +() + +julia> varnamesize(@varname(b[1:3, 2])) +(3,) + +julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1])) +(4, 4) +""" +function varnamesize(vn::VarName) + l = AbstractPPL._last(vn.optic) + if l isa Accessors.IndexLens + return reduce((x, y) -> tuple(x..., y...), map(size, l.indices)) + else + return () + end +end diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 308951608..1340846a9 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -352,7 +352,13 @@ function _concretise_eltype!!(pa::PartialArray) if isconcretetype(eltype(pa)) return pa end - new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing + # and Missing, rather than falling back on Any. However, it's not exported. + new_et = typejoin((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)? # In other words, does it help to be more concrete, even if we aren't fully concrete? 
if new_et === eltype(pa) @@ -588,8 +594,8 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) if size(value) != inds_size throw( DimensionMismatch( - "Assigned value has size $(size(value)), which does not match the size " * - "implied by the indices $(map(x -> _length_needed(x), inds)).", + "Assigned value has size $(size(value)), which does not match the " * + "size implied by the indices $(map(x -> _length_needed(x), inds)).", ), ) end @@ -659,7 +665,14 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) result else # Neither is strictly bigger than the other. - et = promote_type(eltype(pa1), eltype(pa2)) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of + # Nothing and Missing, rather than falling back on Any. However, it's not + # exported. 
+ et = typejoin(eltype(pa1), eltype(pa2)) new_data = Array{et,num_dims}(undef, merge_size) new_mask = fill(false, merge_size) result = PartialArray(new_data, new_mask) From 51b399aeb1f3c4ee29e1029215668b47847e0a15 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 17:33:16 +0000 Subject: [PATCH 16/26] Write more design doc for ArrayLikeBlocks --- docs/src/internals/varnamedtuple.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 47ff9c65e..63f4bb5b9 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -144,6 +144,29 @@ You can also set the elements with `vnt = setindex!!(vnt, @varname(a[1]), 3.0)`, At this point you can not set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, @varname(a[5]), 5.0)`. The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is. +## Non-Array blocks with `IndexLens`es + +The above is all that is needed for setting regular scalar values. +However, in DynamicPPL we also have a particular need for something slightly odd: +We sometimes need to do calls like `setindex!!(vnt, @varname(a[1:5]), val)` on a `val` that is _not_ an `AbstractArray`, or even iterable at all. +Normally this would error: As a scalar value with size `()`, `val` is the wrong size to be set with `@varname(a[1:5])`, which clearly wants something with size `(5,)`. +However, we want to allow this even if `val` is not an iterable, if it is some object for which `size` is well-defined, and `size(val) == (5,)`. +In DynamicPPL this comes up when storing e.g. 
the priors of a model, where a random variable like `@varname(a[1:5])` may be associated with a prior that is a 5-dimensional distribution. + +Internally, a `PartialArray` is just a regular `Array` with a mask saying which elements have been set. +Hence we can't store `val` directly in the same `PartialArray`: +We need it to take up a sub-block of the array, in our example case a sub-block of length 5. +To this end, internally, `PartialArray` uses a wrapper type called `ArrayLikeBlock`, that stores `val` together with the indices that are being used to set it. +The `PartialArray` has all its corresponding elements, in our example elements 1, 2, 3, 4, and 5, point to the same wrapper object. + +While such blocks can be stored using a wrapper like this, some care must be taken in indexing into these blocks. +For instance, after setting a block with `setindex!!(vnt, @varname(a[1:5]), val)`, we can't `getindex(vnt, @varname(a[1]))`, since we can't return "the first element of five in `val`", because `val` may not be indexable in any way. +Similarly, if next we set `setindex!!(vnt, @varname(a[1]), some_other_value)`, that should invalidate/delete the elements `@varname(a[2:5])`, since the block only makes sense as a whole. +Because of these reasons, setting and getting blocks of well-defined size like this is allowed with `VarNamedTuple`s, but _only by always using the full range_. +For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the only valid `getindex` key to access `val` is `@varname(a[1:5])`; +not `@varname(a[1:10])`, nor `@varname(a[3])`, nor anything else that overlaps with `@varname(a[1:5])`. +`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element. 
+ ## Limitations This design has a several of benefits, for performance and generality, but it also has limitations: From 57fd11a30b4bf5b5b55500300d5af6506c7e31d5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 15:14:22 +0000 Subject: [PATCH 17/26] Make VNT support concretized slices --- docs/src/internals/varnamedtuple.md | 1 + src/test_utils.jl | 1 + src/test_utils/models.jl | 111 ++++++++++++++++++++++++++++ src/varnamedtuple.jl | 31 +++++--- test/simple_varinfo.jl | 8 ++ test/varnamedtuple.jl | 21 +++++- 6 files changed, 162 insertions(+), 11 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 63f4bb5b9..aa08c119d 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -50,6 +50,7 @@ The typical use of this structure in DynamicPPL is that the user may define valu This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`. A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be. If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question. +Note however, that concretising the `VarName` resolves this ambiguity, and makes the `VarName` fine as a key to a `VarNamedTuple`. `PartialArray`s have other restrictions, compared to the full indexing syntax of Julia, as well: They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays (as in `rand(4)[[true, false, true, false]]`). 
diff --git a/src/test_utils.jl b/src/test_utils.jl index f584055b3..ebb516844 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1,6 +1,7 @@ module TestUtils using AbstractMCMC +using AbstractPPL: AbstractPPL using DynamicPPL using LinearAlgebra using Distributions diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 84e1f10d8..dcc2d92a2 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -565,6 +565,71 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end +@model function demo_nested_colons( + x=(; data=[(; subdata=transpose([1.5 2.0;]))]), ::Type{TV}=Array{Float64} +) where {TV} + n = length(x.data[1].subdata) + d = n ÷ 2 + s = (; params=[(; subparams=TV(undef, (d, 1, 2)))]) + s.params[1].subparams[:, 1, :] ~ reshape( + product_distribution(fill(InverseGamma(2, 3), n)), d, 2 + ) + s_vec = vec(s.params[1].subparams) + # TODO(mhauru) The below element type concretisation is because of + # https://github.com/JuliaFolds2/BangBang.jl/issues/39 + # which causes, when this is evaluated with an untyped VarInfo, s_vec to be an + # Array{Any}. + s_vec = [x for x in s_vec] + m ~ MvNormal(zeros(n), Diagonal(s_vec)) + + x.data[1].subdata[:, 1] ~ MvNormal(m, Diagonal(s_vec)) + + return (; s=s, m=m, x=x) +end +function logprior_true(model::Model{typeof(demo_nested_colons)}, s, m) + n = length(model.args.x.data[1].subdata) + # TODO(mhauru) We need to enforce a convention on whether this function gets called + # with the parameters as the model returns them, or with the parameters "unpacked". + # Currently different tests do different things. 
+ s_vec = if s isa NamedTuple + vec(s.params[1].subparams) + else + vec(s) + end + return loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) +end +function loglikelihood_true(model::Model{typeof(demo_nested_colons)}, s, m) + # TODO(mhauru) We need to enforce a convention on whether this function gets called + # with the parameters as the model returns them, or with the parameters "unpacked". + # Currently different tests do different things. + s_vec = if s isa NamedTuple + vec(s.params[1].subparams) + else + vec(s) + end + return loglikelihood(MvNormal(m, Diagonal(s_vec)), model.args.x.data[1].subdata) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_nested_colons)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s.params[1].subparams, m) +end +function varnames(::Model{typeof(demo_nested_colons)}) + return [ + @varname( + s.params[1].subparams[ + AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(1))), + 1, + AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))), + ] + ), + # @varname(s.params[1].subparams[1,1,1]), + # @varname(s.params[1].subparams[1,1,2]), + @varname(m), + ] +end + const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, Model{typeof(demo_assume_dot_observe_literal)}, @@ -701,6 +766,51 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemo return vals end +function posterior_mean(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. + vals = rand_prior_true(model) + + vals.s.params[1].subparams[1, 1, 1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s.params[1].subparams[1, 1, 2] = 8 / 3 + vals.m[2] = 1 + + return vals +end +function likelihood_optima(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. + vals = rand_prior_true(model) + + # NOTE: These are "as close to zero as we can get". 
+ vals.s.params[1].subparams[1, 1, 1] = 1e-32 + vals.s.params[1].subparams[1, 1, 2] = 1e-32 + + vals.m[1] = 1.5 + vals.m[2] = 2.0 + + return vals +end +function posterior_optima(model::Model{typeof(demo_nested_colons)}) + # Get some containers to fill. + vals = rand_prior_true(model) + + # TODO: Figure out exact for `s[1]`. + vals.s.params[1].subparams[1, 1, 1] = 0.890625 + vals.s.params[1].subparams[1, 1, 2] = 1 + vals.m[1] = 3 / 4 + vals.m[2] = 1 + + return vals +end +function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_nested_colons)}) + svec = rand(rng, InverseGamma(2, 3), 2) + return (; + s=(; params=[(; subparams=reshape(svec, (1, 1, 2)))]), + m=rand(rng, MvNormal(zeros(2), Diagonal(svec))), + ) +end + """ A collection of models corresponding to the posterior distribution defined by the generative process @@ -749,6 +859,7 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_observe_matrix_index(), demo_assume_matrix_observe_matrix_index(), + demo_nested_colons(), ) """ diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 1340846a9..55f613e87 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -30,7 +30,7 @@ end # The non-generated function implementations of these would be # _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) # function _is_multiindex(::T) where {T<:Tuple} -# return any(x <: UnitRange || x <: Colon for x in T.parameters) +# return any(x <: AbstractUnitRange || x <: Colon for x in T.parameters) # end # However, constant propagation sometimes fails if the index tuple is too big (e.g. length # 4), so we play it safe and use generated functions. 
Constant propagating these is @@ -39,18 +39,18 @@ end @generated function _has_colon(::T) where {T<:Tuple} for x in T.parameters if x <: Colon - return :(true) + return :(return true) end end - return :(false) + return :(return false) end @generated function _is_multiindex(::T) where {T<:Tuple} for x in T.parameters - if x <: UnitRange || x <: Colon - return :(true) + if x <: AbstractUnitRange || x <: Colon || x <: AbstractPPL.ConcretizedSlice + return :(return true) end end - return :(false) + return :(return false) end """ @@ -74,7 +74,10 @@ _merge_recursive(_, x2) = x2 const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" -const INDEX_TYPES = Union{Integer,UnitRange,Colon} +const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.ConcretizedSlice} + +_unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range +_unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x """ ArrayLikeBlock{T,I} @@ -376,7 +379,7 @@ end """Return the length needed in a dimension given an index.""" _length_needed(i::Integer) = i -_length_needed(r::UnitRange) = last(r) +_length_needed(r::AbstractUnitRange) = last(r) """Take the minimum size that a dimension of a PartialArray needs to be, and return the size we choose it to be. This size will be the smallest possible power of @@ -447,12 +450,16 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh throw(BoundsError(pa, inds)) end if _has_colon(inds) - throw(ArgumentError("Indexing PartialArrays with Colon is not supported")) + msg = """ + Indexing PartialArrays with Colon is not supported. + You may need to concretise the `VarName` first.""" + throw(ArgumentError(msg)) end return nothing end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if !(checkbounds(Bool, pa.mask, inds...) 
&& all(@inbounds(getindex(pa.mask, inds...)))) throw(BoundsError(pa, inds)) @@ -501,6 +508,7 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) end function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) hasall = checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) @@ -528,6 +536,7 @@ function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} end function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if _is_multiindex(inds) pa.mask[inds...] .= false @@ -537,7 +546,7 @@ function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) return pa end -_ensure_range(r::UnitRange) = r +_ensure_range(r::AbstractUnitRange) = r _ensure_range(i::Integer) = i:i """ @@ -580,6 +589,7 @@ function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) pa = if checkbounds(Bool, pa.mask, inds...) pa @@ -733,6 +743,7 @@ The there are two major limitations to indexing by VarNamedTuples: * `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined. + However, _concretised_ `VarName`s with `Colon`s are supported. * Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`. 
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 42e377440..2c0e21bec 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -144,6 +144,14 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.ALL_MODELS + if model.f === DynamicPPL.TestUtils.demo_nested_colons + # TODO(mhauru) Either VarNamedVector or SimpleVarInfo has a bug that causes + # the push!! below to fail with a NamedTuple variable like what + # demo_nested_colons has. I don't want to fix it now though, because this may + # all go soon (as of 2025-12-16). + @test false broken = true + continue + end # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 8be72a184..6578d19ae 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -4,7 +4,7 @@ using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock -using AbstractPPL: VarName, prefix +using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!! 
""" @@ -231,6 +231,25 @@ Base.size(st::SizedThing) = st.size vnt = @inferred(setindex!!(vnt, 6, vn5)) @test @inferred(getindex(vnt, vn5)) == 6 test_invariants(vnt) + + # ConcretizedSlices + vnt = VarNamedTuple() + x = [1, 2, 3] + vn = concretize(@varname(y[:]), x) + vnt = @inferred(setindex!!(vnt, x, vn)) + @test haskey(vnt, vn) + @test @inferred(getindex(vnt, vn)) == x + test_invariants(vnt) + + y = fill("a", (3, 2, 4)) + x = y[:, 2, :] + a = (; b=[nothing, nothing, (; c=(; d=reshape(y, (1, 3, 2, 4, 1))))]) + vn = @varname(a.b[3].c.d[1, 3:5, 2, :, 1]) + vn = concretize(vn, a) + vnt = @inferred(setindex!!(vnt, x, vn)) + @test haskey(vnt, vn) + @test @inferred(getindex(vnt, vn)) == x + test_invariants(vnt) end @testset "equality and hash" begin From 7bdce5ce53a9dc410ad140f782d4ad4c7b31f939 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 16:47:50 +0000 Subject: [PATCH 18/26] Start the VNT HISTORY.md entry --- HISTORY.md | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 9dc4414ce..0ad1824dd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,112 @@ ## 0.40 +### Changes to indexing random variables with square brackets + +0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. ones that use indexing with square brackets. +Most of this is invisible to users, but it has some effects that show on the surface. +The gist of the changes is that any indexing by square brackets is now implicitly assumed to be indexing into a regular `Base.Array`, with 1-based indexing. +The general effect this has is that the new rules on what is and isn't allowed are stricter, forbidding some old syntax that used to be allowed, and at the same time guaranteeing that it works correctly. +(Previously there were some sharp edges around these sorts of variable names.) 
+ +#### No more linear indexing of multidimensional arrays + +Previously you could do this: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1] ~ Normal() +x[1, 1] ~ Normal() +``` + +Now you can't, this will error. +If you first create a variable like `x[1]`, DynamicPPL from there on assumes that this variable only takes a single index (like a `Vector`). +It will then error if you try to index the same variable with any other number of indices. + +The same logic also bans this, which likewise was previously allowed: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1, 1, 1] ~ Normal() +x[1, 1] ~ Normal() +``` + +This made use of Julia allowing trailing indices of `1`. + +Note that the above models were previously quite dangerous and easy to misuse, because DynamicPPL was oblivious to the fact that e.g. `x[1]` and `x[1,1]` refer to the same element. +Both of the above examples previously created 2-dimensional models, with two distinct random variables, one of which effectively overwrote the other in the model body. + +TODO(mhauru) This may cause surprising issues when using `eachindex`, which is generally encouraged, e.g. + +``` +x = Array{Float64,2}(undef, (3, 3)) +for i in eachindex(x) + x[i] ~ Normal() +end +``` + +Maybe we should fix linear indexing before releasing? + +#### No more square bracket indexing with arbitrary keys + +Previously you could do this: + +```julia +x = Dict() +x["a"] ~ Normal() +``` + +Now you can't, this will error. +This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `"a"` is not a valid index. +You can still use a dictionary on the left-hand side of a `~` statement as long as the indices are valid indices to an `Array`, e.g. integers. 
+ +#### No more unusually indexed arrays, such as `OffsetArrays` + +Previously you could do this: + +```julia +using OffsetArrays +x = OffsetArray(Vector{Float64}(undef, 3), -3) +x[-2] ~ Normal() +0.0 ~ Normal(x[-2]) +``` + +Now you can't, this will error. +This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `-2` is not a valid index. + +#### The above limitations are not fundamental + +The above, new restrictions to what sort of variable names are allowed aren't fundamental. +With some effort we could e.g. add support for linear indexing, this time done properly, so that e.g. `x[1,1]` and `x[1]` would be the same variable. +Likewise, we could manually add structures to support indexing into dictionaries or `OffsetArrays`. +If this would be useful to you, let us know. + +#### This only affects `~` statements + +You can still use any arbitrary indexing within your model in statements that don't involve `~`. +For instance, you can use `OffsetArray`s, or linear indexing, as long as you don't put them on the left-hand side of a `~`. + +#### Performance benefits + +The upside of all these new limitations is that models that use square bracket indexing are now faster. +For instance, take the following model: + +```julia +@model function f() + x = Vector{Float64}(undef, 1000) + for i in eachindex(x) + x[i] ~ Normal() + end + return 0.0 ~ Normal(sum(x)) +end +``` + +Evaluating the log joint for this model has gotten about 3 times faster in v0.40. + +#### Robustness benefits + +TODO(mhauru) Add an example here for how this improves `condition`ing, once `condition` uses `VarNamedTuple`. + ## 0.39.4 Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed). 
From 9992051225b887f45da466c0067e113d17280b71 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 17:48:03 +0000 Subject: [PATCH 19/26] Skip a type stability test on 1.10 --- test/model.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/model.jl b/test/model.jl index c878fd905..30fc614ca 100644 --- a/test/model.jl +++ b/test/model.jl @@ -408,6 +408,14 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] @testset "$(model.f)" for model in models_to_test + if model.f === DynamicPPL.TestUtils.demo_nested_colons && VERSION < v"1.11" + # On v1.10, the demo_nested_colons model, which uses a lot of + # NamedTuples, is badly type unstable. Not worth doing much about + # it, since it's fixed on later Julia versions, so just skipping + # these tests. + @test_skip false skip = true + continue + end vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( From 753ca81b85af88adb0970dff88670dda2445fa4d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 16 Dec 2025 18:32:34 +0000 Subject: [PATCH 20/26] Fix test_skip --- test/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model.jl b/test/model.jl index 30fc614ca..3272fd8b5 100644 --- a/test/model.jl +++ b/test/model.jl @@ -413,7 +413,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # NamedTuples, is badly type unstable. Not worth doing much about # it, since it's fixed on later Julia versions, so just skipping # these tests. 
- @test_skip false skip = true + @test false skip = true continue end vns = DynamicPPL.TestUtils.varnames(model) From d9e5405df819835daa81429e51343659cb444e3d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 19 Dec 2025 10:42:24 +0000 Subject: [PATCH 21/26] Mark a test as broken on 1.10 --- test/logdensityfunction.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f96e7bf27..9f30d7b68 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -113,7 +113,10 @@ end end ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) x = vi[:] - @inferred LogDensityProblems.logdensity(ldf, x) + # The below type inference fails on v1.10. + @test begin + @inferred LogDensityProblems.logdensity(ldf, x) + end broken = (VERSION < v"1.11.0") end end end From 267c55471f44bb1ae03daf4877a25405bb79b437 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 19 Dec 2025 11:46:12 +0000 Subject: [PATCH 22/26] Trivial bug fix --- test/logdensityfunction.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 9f30d7b68..2e3c56c53 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -116,6 +116,7 @@ end # The below type inference fails on v1.10. 
@test begin @inferred LogDensityProblems.logdensity(ldf, x) + true end broken = (VERSION < v"1.11.0") end end From 76ac5b617218f093c35dfd0b3403f9ba10902a7f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 19 Dec 2025 18:06:26 +0000 Subject: [PATCH 23/26] Use skip rather than broken for an inference test --- test/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 2e3c56c53..b58006de2 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -117,7 +117,7 @@ end @test begin @inferred LogDensityProblems.logdensity(ldf, x) true - end broken = (VERSION < v"1.11.0") + end skip = (VERSION < v"1.11.0") end end end From 0c50bd74276dd1e5bd3efa0e00ba0766bc81ee86 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:21:46 +0000 Subject: [PATCH 24/26] Fix a docs typo Co-authored-by: Penelope Yong --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 0ad1824dd..bb40b8464 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -40,7 +40,7 @@ Both of the above examples previously created 2-dimensional models, with two dis TODO(mhauru) This may cause surprising issues when using `eachindex`, which is generally encouraged, e.g. 
``` -x = Array{Float64,2}(undef, (3, 3) +x = Array{Float64,2}(undef, (3, 3)) for i in eachindex(x) x[i] ~ Normal() end From 6b211b105739a470d69d11b18d81ba9b069e0fb1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:26:31 +0000 Subject: [PATCH 25/26] Use floatmin in test_utils --- src/test_utils/models.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index dcc2d92a2..283d5f01b 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -680,8 +680,8 @@ function likelihood_optima(model::MultivariateAssumeDemoModels) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". - vals.s[1] = 1e-32 - vals.s[2] = 1e-32 + vals.s[1] = floatmin() + vals.s[2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 @@ -733,8 +733,8 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". - vals.s[1, 1] = 1e-32 - vals.s[1, 2] = 1e-32 + vals.s[1, 1] = floatmin() + vals.s[1, 2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 @@ -783,8 +783,8 @@ function likelihood_optima(model::Model{typeof(demo_nested_colons)}) vals = rand_prior_true(model) # NOTE: These are "as close to zero as we can get". 
- vals.s.params[1].subparams[1, 1, 1] = 1e-32 - vals.s.params[1].subparams[1, 1, 2] = 1e-32 + vals.s.params[1].subparams[1, 1, 1] = floatmin() + vals.s.params[1].subparams[1, 1, 2] = floatmin() vals.m[1] = 1.5 vals.m[2] = 2.0 From 57fd84b20c0ad7c251a415739a64ef186fafa2c7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 6 Jan 2026 14:38:42 +0000 Subject: [PATCH 26/26] Narrow a skip clause --- test/logdensityfunction.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index b58006de2..7014140b9 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -114,10 +114,11 @@ end ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) x = vi[:] # The below type inference fails on v1.10. + skip = (VERSION < v"1.11.0" && m.f === DynamicPPL.TestUtils.demo_nested_colons) @test begin @inferred LogDensityProblems.logdensity(ldf, x) true - end skip = (VERSION < v"1.11.0") + end skip = skip end end end