Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/accumulators/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end
Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero.
"""
(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T))
(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}()
(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType[]}()

Base.copy(acc::LogProbAccumulator) = acc

Expand Down Expand Up @@ -175,7 +175,7 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn, te
return acclogp(acc, Distributions.loglikelihood(right, left))
end

function default_accumulators(::Type{FloatT}=LogProbType) where {FloatT}
function default_accumulators(::Type{FloatT}=LogProbType[]) where {FloatT}
return AccumulatorTuple(
LogPriorAccumulator{FloatT}(),
LogJacobianAccumulator{FloatT}(),
Expand Down
6 changes: 3 additions & 3 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ function Distributions.rand!(
return Distributions.rand!(rng, d.dist, x)
end
function Distributions.logpdf(::NoDist{<:Univariate}, x::Real)
return zero(LogProbType)
return zero(LogProbType[])
end
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real})
return zero(LogProbType)
return zero(LogProbType[])
end
function Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real})
return zero(LogProbType)
return zero(LogProbType[])
end

for f in (
Expand Down
6 changes: 3 additions & 3 deletions src/transformed_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ function apply_transform_strategy(
# vectorisation transform. However, sometimes that's not needed (e.g. when
# evaluating with an OnlyAccsVarInfo). So we just return an UntransformedValue. If a
# downstream function requires a VectorValue, it's on them to generate it.
(raw_value, UntransformedValue(raw_value), zero(LogProbType))
(raw_value, UntransformedValue(raw_value), zero(LogProbType[]))
else
error("unknown target transform $target")
end
Expand All @@ -383,7 +383,7 @@ function apply_transform_strategy(
(raw_value, linked_tv, logjac)
elseif target isa Unlink
# No need to transform further
(raw_value, tv, zero(LogProbType))
(raw_value, tv, zero(LogProbType[]))
else
error("unknown target transform $target")
end
Expand All @@ -406,7 +406,7 @@ function apply_transform_strategy(
(raw_value, linked_tv, logjac)
elseif target isa Unlink
# No need to transform further
(raw_value, tv, zero(LogProbType))
(raw_value, tv, zero(LogProbType[]))
else
error("unknown target transform $target")
end
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The type for all log probability variables.

This is Float64 on 64-bit systems and Float32 on 32-bit systems.
"""
const LogProbType = float(Real)
const LogProbType = Ref(float(Real))

"""
typed_identity(x)
Expand Down Expand Up @@ -42,7 +42,7 @@ behaviour.
function typed_identity end
@inline typed_identity(x) = x
@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) =
(x, zero(LogProbType))
(x, zero(LogProbType[]))
@inline Bijectors.inverse(::typeof(typed_identity)) = typed_identity

"""
Expand Down
Loading