Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.40.17

Implemented missing methods for `Base.copy` on internal struct.

# 0.40.16

Fixed `Base.copy` for `VNTAccumulator` and `TSVNTAccumulator` to also copy the `acc.f` field, not just `acc.values`.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.40.16"
version = "0.40.17"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions src/accumulators/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ where `Prior` and `Likelihood` are the boolean type parameters. This accumulator
store the log-probabilities for all tilde-statements in the model.
"""
struct PointwiseLogProb{Prior,Likelihood} end
Base.copy(plp::PointwiseLogProb) = plp
function (plp::PointwiseLogProb{Prior,Likelihood})(
val, tval, logjac, vn, dist
) where {Prior,Likelihood}
Expand Down
2 changes: 2 additions & 0 deletions src/accumulators/raw_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ struct GetRawValues
"A flag indicating whether variables on the LHS of := should also be included"
include_colon_eq::Bool
end
Base.copy(g::GetRawValues) = g
# TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model body
# can go mutating the object without that in turn mutating the value stored in the
# accumulator, which should be as it was at `~` time. Could there be a way around this?
Expand Down Expand Up @@ -49,6 +50,7 @@ struct DebugGetRawValues
repeated_vns::Set{VarName}
end
is_extracting_colon_eq_values(g::DebugGetRawValues) = true
Base.copy(d::DebugGetRawValues) = DebugGetRawValues(copy(d.repeated_vns))
function DebugRawValueAccumulator()
return VNTAccumulator{RAW_VALUE_ACCNAME}(DebugGetRawValues(Set{VarName}()))
end
Expand Down
5 changes: 5 additions & 0 deletions src/accumulators/vnt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The function `f` should have the signature:
where `val`, `tval`, `logjac`, `vn`, and `dist` have their usual meanings in
accumulate_assume!! (see its docstring for more details). If a value does not need to
be accumulated, this can be signalled by returning `DoNotAccumulate()` from `f`.

If `f` is a struct (and not a function), you also need to define `Base.copy` on `f`.
"""
struct VNTAccumulator{AccName,F,VNT<:VarNamedTuple} <: AbstractAccumulator
f::F
Expand Down Expand Up @@ -61,6 +63,9 @@ for acc_type in (:VNTAccumulator, :TSVNTAccumulator)
function Base.copy(acc::$acc_type{AccName}) where {AccName}
return $acc_type{AccName}(copy(acc.f), copy(acc.values))
end
function Base.copy(acc::$acc_type{AccName,F}) where {AccName,F<:Function}
return $acc_type{AccName}(acc.f, copy(acc.values))
end
accumulator_name(::$acc_type{AccName}) where {AccName} = AccName

function update_values(
Expand Down
22 changes: 22 additions & 0 deletions test/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ using DynamicPPL:
get_priors,
@varname

TEST_ACCUMULATORS = (
LogPriorAccumulator(1.0),
LogLikelihoodAccumulator(1.0),
LogJacobianAccumulator(1.0),
RawValueAccumulator(false),
DynamicPPL.DebugRawValueAccumulator(),
DynamicPPL.BijectorAccumulator(),
DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}(
DynamicPPL.PointwiseLogProb{true,true}()
),
PriorDistributionAccumulator(),
DynamicPPL.VectorValueAccumulator(),
DynamicPPL.VectorParamAccumulator([], Bool[], VarNamedTuple()),
)

@testset "accumulators" begin
@testset "individual accumulator types" begin
@testset "constructors" begin
Expand Down Expand Up @@ -211,6 +226,13 @@ using DynamicPPL:
@test priors[@varname(x)] == Normal()
@test priors[@varname(y)] == Normal(vals[@varname(x)])
end

@testset "Base.copy" begin
for acc in TEST_ACCUMULATORS
# just check that it works.
@test copy(acc) isa Any
end
end
end

@info "Completed $(@__FILE__) in $(now() - __now__)."
Expand Down
Loading