diff --git a/HISTORY.md b/HISTORY.md index ecec98217..e3a1fa9de 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.40.17 + +Implemented missing methods for `Base.copy` on internal struct. + # 0.40.16 Fixed `Base.copy` for `VNTAccumulator` and `TSVNTAccumulator` to also copy the `acc.f` field, not just `acc.values`. diff --git a/Project.toml b/Project.toml index d01c1fcc0..86503e8d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.40.16" +version = "0.40.17" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index 63e5aa4b9..ac2f59b7d 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -13,6 +13,7 @@ where `Prior` and `Likelihood` are the boolean type parameters. This accumulator store the log-probabilities for all tilde-statements in the model. """ struct PointwiseLogProb{Prior,Likelihood} end +Base.copy(plp::PointwiseLogProb) = plp function (plp::PointwiseLogProb{Prior,Likelihood})( val, tval, logjac, vn, dist ) where {Prior,Likelihood} diff --git a/src/accumulators/raw_values.jl b/src/accumulators/raw_values.jl index e5b1df1de..cad84e9fd 100644 --- a/src/accumulators/raw_values.jl +++ b/src/accumulators/raw_values.jl @@ -4,6 +4,7 @@ struct GetRawValues "A flag indicating whether variables on the LHS of := should also be included" include_colon_eq::Bool end +Base.copy(g::GetRawValues) = g # TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model body # can go mutating the object without that in turn mutating the value stored in the # accumulator, which should be as it was at `~` time. Could there be a way around this? @@ -49,6 +50,7 @@ struct DebugGetRawValues repeated_vns::Set{VarName} end is_extracting_colon_eq_values(g::DebugGetRawValues) = true +Base.copy(d::DebugGetRawValues) = DebugGetRawValues(copy(d.repeated_vns)) function DebugRawValueAccumulator() return VNTAccumulator{RAW_VALUE_ACCNAME}(DebugGetRawValues(Set{VarName}())) end diff --git a/src/accumulators/vnt.jl b/src/accumulators/vnt.jl index f9e144104..5fbc5e709 100644 --- a/src/accumulators/vnt.jl +++ b/src/accumulators/vnt.jl @@ -16,6 +16,8 @@ The function `f` should have the signature: where `val`, `tval`, `logjac`, `vn`, and `dist` have their usual meanings in accumulate_assume!! (see its docstring for more details). If a value does not need to be accumulated, this can be signalled by returning `DoNotAccumulate()` from `f`. + +If `f` is a struct (and not a function), you also need to define `Base.copy` on `f`. """ struct VNTAccumulator{AccName,F,VNT<:VarNamedTuple} <: AbstractAccumulator f::F @@ -61,6 +63,9 @@ for acc_type in (:VNTAccumulator, :TSVNTAccumulator) function Base.copy(acc::$acc_type{AccName}) where {AccName} return $acc_type{AccName}(copy(acc.f), copy(acc.values)) end + function Base.copy(acc::$acc_type{AccName,F}) where {AccName,F<:Function} + return $acc_type{AccName}(acc.f, copy(acc.values)) + end accumulator_name(::$acc_type{AccName}) where {AccName} = AccName function update_values( diff --git a/test/accumulators.jl b/test/accumulators.jl index 6ecde3e84..ba2585f86 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -25,6 +25,21 @@ using DynamicPPL: get_priors, @varname +TEST_ACCUMULATORS = ( + LogPriorAccumulator(1.0), + LogLikelihoodAccumulator(1.0), + LogJacobianAccumulator(1.0), + RawValueAccumulator(false), + DynamicPPL.DebugRawValueAccumulator(), + DynamicPPL.BijectorAccumulator(), + DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( + DynamicPPL.PointwiseLogProb{true,true}() + ), + PriorDistributionAccumulator(), + DynamicPPL.VectorValueAccumulator(), + DynamicPPL.VectorParamAccumulator([], Bool[], VarNamedTuple()), +) + @testset "accumulators" begin @testset "individual accumulator types" begin @testset "constructors" begin @@ -211,6 +226,13 @@ using DynamicPPL: @test priors[@varname(x)] == Normal() @test priors[@varname(y)] == Normal(vals[@varname(x)]) end + + @testset "Base.copy" begin + for acc in TEST_ACCUMULATORS + # just check that it works. + @test copy(acc) isa Any + end + end end @info "Completed $(@__FILE__) in $(now() - __now__)."