From d0edd81c5a56fe6cbf272dccafdf73e076428224 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 13 Mar 2026 10:20:50 -0400 Subject: [PATCH 01/10] first pass at the SMC interface --- src/mcmc/Inference.jl | 2 +- src/mcmc/smc.jl | 431 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 432 insertions(+), 1 deletion(-) create mode 100644 src/mcmc/smc.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 73f0661e6..5daf33dd8 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -96,7 +96,7 @@ include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") -include("particle_mcmc.jl") +include("smc.jl") include("sghmc.jl") include("emcee.jl") include("prior.jl") diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl new file mode 100644 index 000000000..3f53d148b --- /dev/null +++ b/src/mcmc/smc.jl @@ -0,0 +1,431 @@ +#### +#### Combining DynamicPPL and Libtask. +#### + +mutable struct TracedModel{T<:TapedTask} + const task::T + varinfo::AbstractVarInfo +end + +function construct_task(rng::AbstractRNG, model::Model, vi::AbstractVarInfo) + inner_rng = Random.seed!(Random123.Philox2x(), rand(rng, Random.Sampler(rng, UInt64))) + inner_model = DynamicPPL.setleafcontext(model, SMCContext()) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(inner_model, vi) + return TapedTask(inner_rng, inner_model.f, args...; kwargs...) +end + +# overload consume to store the local varinfo +function Libtask.consume(trace::TracedModel) + score = Libtask.consume(trace.task) + set_varinfo!(trace, score) + return score +end + +# apply the same iteration utilities as a TapedTask +function Base.iterate(trace::TracedModel, ::Nothing=nothing) + v = Libtask.consume(trace) + return v === nothing ? 
nothing : (v, nothing) +end + +Base.IteratorSize(::Type{<:TracedModel}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:TracedModel}) = Base.EltypeUnknown() + +# these will be useful when constructing new traces from proposed values +get_model(trace::TracedModel) = trace.task.fargs[2] +get_varinfo(trace::TracedModel) = trace.varinfo +get_rng(trace::TracedModel) = trace.task.taped_globals + +function TracedModel(rng::AbstractRNG, model::Model) + vi = DynamicPPL.setacc!!(VarInfo(model), ProduceLogLikelihoodAccumulator()) + vi = DynamicPPL.empty!!(vi) + return TracedModel(construct_task(rng, model, vi), vi) +end + +# if score is nothing, the varinfo is caught up and there's no need to update +set_varinfo!(::TracedModel, ::Nothing) = nothing +set_varinfo!(trace::TracedModel, ::Real) = (trace.varinfo = task_local_storage(:varinfo); ) + +struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T} + logp::T +end + +DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood +DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp + +# this is the only difference between LogLikelihoodAccumulator +function DynamicPPL.acclogp(acc::ProduceLogLikelihoodAccumulator, val) + task_local_storage(:logscore, val) + newacc = ProduceLogLikelihoodAccumulator(DynamicPPL.logp(acc) + val) + return newacc +end + +function DynamicPPL.accumulate_assume!!( + acc::ProduceLogLikelihoodAccumulator, val, tval, logjac, vn, dist, template +) + return acc +end + +function DynamicPPL.accumulate_observe!!( + acc::ProduceLogLikelihoodAccumulator, dist, val, vn, template +) + return DynamicPPL.acclogp(acc, Distributions.loglikelihood(dist, val)) +end + +# Relevant call chains: +# tilde_observe!! -> accumulate_observe!! -> acclogp -> produce +Libtask.@might_produce(DynamicPPL.tilde_observe!!) +Libtask.@might_produce(DynamicPPL.accumulate_observe!!) +Libtask.@might_produce(DynamicPPL.acclogp) + +# tilde_assume!! in Gibbs -> tilde_observe!! 
-> ... +Libtask.@might_produce(DynamicPPL.tilde_assume!!) + +# @addlogprob!(::Number) -> accloglikelihood!! -> map_accumulator!! -> acclogp -> produce +Libtask.@might_produce(DynamicPPL.accloglikelihood!!) +Libtask.@might_produce(DynamicPPL.map_accumulator!!) + +# @addlogprob!(::NamedTuple) -> acclogp!! -> accloglikelihood!! -> ... +Libtask.@might_produce(DynamicPPL.acclogp!!) + +# Generic catch-all to handle submodels and kwargs on models, see +# https://github.com/TuringLang/Libtask.jl/issues/217 +Libtask.might_produce_if_sig_contains(::Type{<:DynamicPPL.Model}) = true + +struct SMCContext <: DynamicPPL.AbstractContext end + +function init_context(rng::AbstractRNG, vi::VarInfo, vn::VarName) + if ~haskey(vi, vn) + return InitContext(rng, InitFromPrior(), vi.transform_strategy) + else + return DefaultContext() + end +end + +function DynamicPPL.tilde_assume!!( + ::SMCContext, dist::Distribution, vn::VarName, template::Any, vi::AbstractVarInfo +) + rng = Libtask.get_taped_globals(AbstractRNG) + dispatch_ctx = init_context(rng, vi, vn) + val, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, template, vi) + return val, vi +end + +function DynamicPPL.tilde_observe!!( + ::SMCContext, dist::Distribution, val, vn::Union{VarName,Nothing}, template, vi::AbstractVarInfo +) + val, vi = DynamicPPL.tilde_observe!!(DefaultContext(), dist, val, vn, template, vi) + task_local_storage(:varinfo, vi) + Libtask.produce(task_local_storage(:logscore)) + return val, vi +end + +#### +#### Resampling Schemes. +#### + +abstract type AbstractResampler end + +struct AlwaysResample <: AbstractResampler end + +function should_resample(::AbstractVector, ::AlwaysResample) + return true +end + +struct ESSResampler{T<:Real} <: AbstractResampler + threshold::T +end + +function should_resample(weights::AbstractVector, resampler::ESSResampler) + ess = inv(sum(abs2, weights)) + return ess ≤ resampler.threshold * length(weights) +end + +#### +#### Particle Containers. 
+#### + +""" + Particle +""" +mutable struct Particle{PT,WT<:Real} + const value::PT + logw::WT +end + +Particle(value) = Particle(value, 0.0) + +""" + ParticleContainer + +A custom array object to handle getting and setting of particle values as well as their log- +weights. This allows a plethora of in-place operations and intuitive handling throughout the +SMC process. +""" +const ParticleContainer{PT,WT} = Vector{Particle{PT,WT}} +ParticleContainer(values::AbstractVector) = Particle.(values) + +# this is quite overkill, so I might ditch it in future versions +@inline Base.getproperty(pc::ParticleContainer, s::Symbol) = _getproperty(pc, Val(s)) +@inline _getproperty(pc::ParticleContainer, ::Val{:values}) = @. getproperty(pc, :value) +@inline _getproperty(pc::ParticleContainer, ::Val{:log_weights}) = @. getproperty(pc, :logw) +@inline _getproperty(pc::ParticleContainer, ::Val{S}) where {S} = getfield(pc, S) + +function StatsBase.weights(particles::ParticleContainer) + return weights(softmax(particles.log_weights)) +end + +function StatsBase.sample(rng::AbstractRNG, particles::ParticleContainer) + return sample(rng, particles.values, weights(particles)) +end + +function resample!(rng::AbstractRNG, particles::ParticleContainer, weights::Weights) + idx = sample_ancestors(rng, weights.values) + @. particles = Particle($split!(rng, particles.values[idx])) +end + +# TODO: optimize this for the love of god +function split!(rng::AbstractRNG, particles::Vector{<:TracedModel}) + children = deepcopy.(particles) + seeds = rand(rng, Random.Sampler(rng, UInt64), length(particles)) + @. Random.seed!(get_rng(children), seeds) + return children +end + +advance!(particle::Particle{<:TracedModel}) = consume(particle.value) + +""" + ReferencedContainer + +An object which associates a given reference trajectory with a given particle container. One +can access `container.values` and `container.log_weights` as before, where the final element +of the vector is the reference trajectory. 
+""" +struct ReferencedContainer{PT,WT,RT} + particles::ParticleContainer{PT,WT} + reference::Particle{RT,WT} +end + +Base.length(pc::ReferencedContainer) = length(pc.particles) + 1 +Base.keys(pc::ReferencedContainer) = LinearIndices(pc.values) + +Base.iterate(pc::ReferencedContainer) = iterate(pc.particles) + +function Base.iterate(pc::ReferencedContainer, i) + i == length(pc) && return (pc.reference, i + 1) + return iterate(pc.particles, i) +end + +function Base.getindex(pc::ReferencedContainer, i) + i == length(pc) && return pc.reference + return pc.particles[i] +end + +function Base.collect(pc::ReferencedContainer) + particles = Vector{eltype(pc.particles)}(undef, length(pc)) + particles[1:end-1] = @views(pc.particles) + particles[end] = pc.reference + return particles +end + +Base.getproperty(pc::ReferencedContainer, s::Symbol) = _getproperty(pc, Val(s)) + +function _getproperty(pc::ReferencedContainer{PT}, ::Val{:values}) where {PT} + values = Vector{PT}(undef, length(pc.particles) + 1) + values[1:end-1] = @views(pc.particles.values) + values[end] = pc.reference.value + return values +end + +function _getproperty(pc::ReferencedContainer{PT,WT}, ::Val{:log_weights}) where {PT, WT} + log_weights = Vector{WT}(undef, length(pc.particles) + 1) + log_weights[1:end-1] = @views(pc.particles.log_weights) + log_weights[end] = pc.reference.logw + return log_weights +end + +_getproperty(pc::ReferencedContainer, ::Val{S}) where {S} = getfield(pc, S) + +function StatsBase.weights(particles::ReferencedContainer) + return weights(softmax(particles.log_weights)) +end + +function StatsBase.sample(rng::AbstractRNG, particles::ReferencedContainer) + return sample(rng, particles.values, weights(particles)) +end + +function resample!(rng::AbstractRNG, ref::ReferencedContainer, weights::Weights) + idx = sample_ancestors(rng, weights.values, length(ref.particles)) + @. 
ref.particles = Particle($split!(rng, ref.values[idx])) + return ref +end + +#### +#### Generic Sequential Monte Carlo sampler. +#### + +""" + SMC{RT,KT} + +A basic Sequential Monte Carlo sampler, resampling according to scheme RT and rejuvenating +the sample according to the kernel KT. + +By default this is set to always resample and never to rejuvenate, as was done in previous +versions of Turing. + +See [`ParticleGibbs`](@ref) for use within Markov Chain Monte Carlo. +""" +struct SMC{RT,KT} <: AbstractSampler + resampler::RT + kernel::KT +end + +SMC(threshold::Real) = SMC(ESSResampler(threshold), nothing) +SMC() = SMC(AlwaysResample(), nothing) + +function initialize(rng::AbstractRNG, model::DynamicPPL.Model, ::SMC, N::Integer) + return ParticleContainer([TracedModel(rng, model) for _ in 1:N]), false +end + +function initialize( + rng::AbstractRNG, model::DynamicPPL.Model, sampler::SMC, N::Integer, ref +) + particles, is_done = if isnothing(ref) + initialize(rng, model, sampler, N) + else + particles, is_done = initialize(rng, model, sampler, N - 1) + ReferencedContainer(particles, Particle(ref)), is_done + end +end + +# TODO: replace this with a systematic resampler +function sample_ancestors( + rng::AbstractRNG, weights::Vector{<:Real}, N::Integer=length(weights) +) + return rand(rng, Categorical(weights), N) +end + +increment_weight!(particle::Particle, score::Nothing) = true +increment_weight!(particle::Particle, score::Real) = (particle.logw += score; return false) + +function reweight!(particles, ::AbstractMCMC.MCMCSerial) + num_done = map(particles) do particle + score = advance!(particle) + increment_weight!(particle, score) + end + return all(num_done) +end + +function reweight!(particles, ::AbstractMCMC.MCMCThreads) + num_done = Vector{Bool}(undef, length(particles)) + Threads.@threads for i in eachindex(particles) + score = advance!(particles[i]) + num_done[i] = increment_weight!(particles[i], score) + end + return all(num_done) +end + +function 
maybe_resample!( + rng::AbstractRNG, + particles, + resampler::AbstractResampler +) + weights = StatsBase.weights(particles) + rs_flag = should_resample(weights, resampler) + rs_flag && resample!(rng, particles, weights) + return rs_flag +end + +# leave out rejuvenation for now, we'll cross that bridge when we get there +function rejuvenate!( + ::AbstractRNG, + particles::ParticleContainer, + ::Nothing, + ::AbstractMCMC.AbstractMCMCEnsemble, + ::Bool, + ::Integer; + kwargs... +) + return particles +end + +function smcsample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::SMC, + ensemble::AbstractMCMC.AbstractMCMCEnsemble, + N::Integer; + ref=nothing +) + particles, is_done = initialize(rng, model, sampler, N, ref) + iter = 0 + while !is_done + rs_flag = maybe_resample!(rng, particles, sampler.resampler) + particles = rejuvenate!(rng, particles, sampler.kernel, ensemble, rs_flag, iter) + is_done = reweight!(particles, ensemble) + iter += 1 + end + return particles +end + +# this doesn't return an MCMCChain which kinda blows, but whatever +function AbstractMCMC.sample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::SMC, + N::Integer; + ensemble::AbstractMCMC.AbstractMCMCEnsemble=MCMCSerial(), + kwargs... +) + return smcsample(rng, model, sampler, ensemble, N; kwargs...) +end + +#### +#### Particle Gibbs and Conditional SMC. +#### + +""" + ParticleGibbs + +An MCMC sampler which wraps a conditional Sequential Monte Carlo step at every iteration of +the sampler for `N` iterations. + +See [`SMC`](@ref) for details on the Sequential Monte Carlo kernel. 
+ +# Examples +```julia +# samples 128 particles each iteration, resampling when ESS drops below 50% +chain = sample(model, PG(SMC(0.5), 128), 10_000) +``` +""" +struct ParticleGibbs{T<:AbstractSMC} <: AbstractMCMC.AbstractSampler + kernel::T + N::Int +end + +const PG{T} = ParticleGibbs{T} + +function AbstractMCMC.step( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::ParticleGibbs; + kwargs..., +) + particles = smcsample(rng, model, sampler.kernel, MCMCSerial(), sampler.N); + state = sample(rng, particles) + return get_varinfo(state), state +end + +# NOTE: this needs some TLC, not sure how I can better integrate +function AbstractMCMC.step( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::ParticleGibbs, + state::TracedModel; + kwargs..., +) + particles = smcsample( + rng, model, sampler.kernel, MCMCSerial(), sampler.N; ref=deepcopy(state) + ) + state = sample(rng, particles) + return get_varinfo(state), state +end From 5ca3c31cbc0e12fc14441b26fe38bdf17a1eb92a Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 13 Mar 2026 10:46:33 -0400 Subject: [PATCH 02/10] ensure precompilation --- Project.toml | 2 ++ src/mcmc/Inference.jl | 2 ++ src/mcmc/gibbs.jl | 11 ++++++----- src/mcmc/smc.jl | 6 +++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 38b0736b7..dd7b05b34 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -73,6 +74,7 @@ OptimizationOptimJL = "0.1 - 0.4" OrderedCollections = "1" Printf = "1" Random = "1" +Random123 = "1.7.1" Reexport = "0.2, 1" SciMLBase 
= "2" SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 5daf33dd8..586df217b 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -23,6 +23,7 @@ using Random: AbstractRNG using AbstractMCMC: AbstractModel, AbstractSampler using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS using DataStructures: OrderedSet, OrderedDict +using StatsBase import ADTypes import AbstractMCMC @@ -35,6 +36,7 @@ import AdvancedPS import EllipticalSliceSampling import LogDensityProblems import Random +import Random123 import MCMCChains import StatsBase: predict diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 6e7f796e7..cd127de36 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -534,11 +534,12 @@ function setparams_varinfo!!( return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.ldf) end -function setparams_varinfo!!( - ::DynamicPPL.Model, ::PG, state::PGState, params::AbstractVarInfo -) - return PGState(params, state.rng) -end +# TODO: not sure what to do here... I'll get there eventually +# function setparams_varinfo!!( +# ::DynamicPPL.Model, ::PG, state::PGState, params::AbstractVarInfo +# ) +# return PGState(params, state.rng) +# end """ match_linking!!(varinfo_local, prev_state_local, model) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 3f53d148b..838751550 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -178,7 +178,7 @@ function StatsBase.sample(rng::AbstractRNG, particles::ParticleContainer) return sample(rng, particles.values, weights(particles)) end -function resample!(rng::AbstractRNG, particles::ParticleContainer, weights::Weights) +function resample!(rng::AbstractRNG, particles::ParticleContainer, weights::StatsBase.Weights) idx = sample_ancestors(rng, weights.values) @. 
particles = Particle($split!(rng, particles.values[idx])) end @@ -253,7 +253,7 @@ function StatsBase.sample(rng::AbstractRNG, particles::ReferencedContainer) return sample(rng, particles.values, weights(particles)) end -function resample!(rng::AbstractRNG, ref::ReferencedContainer, weights::Weights) +function resample!(rng::AbstractRNG, ref::ReferencedContainer, weights::StatsBase.Weights) idx = sample_ancestors(rng, weights.values, length(ref.particles)) @. ref.particles = Particle($split!(rng, ref.values[idx])) return ref @@ -397,7 +397,7 @@ See [`SMC`](@ref) for details on the Sequential Monte Carlo kernel. chain = sample(model, PG(SMC(0.5), 128), 10_000) ``` """ -struct ParticleGibbs{T<:AbstractSMC} <: AbstractMCMC.AbstractSampler +struct ParticleGibbs{T<:SMC} <: AbstractMCMC.AbstractSampler kernel::T N::Int end From 9747f464e0ac662955c8870e79568fdb2b204b19 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 13 Mar 2026 11:02:35 -0400 Subject: [PATCH 03/10] formatter --- src/mcmc/smc.jl | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 838751550..1aaf857ff 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -43,7 +43,7 @@ end # if score is nothing, the varinfo is caught up and there's no need to update set_varinfo!(::TracedModel, ::Nothing) = nothing -set_varinfo!(trace::TracedModel, ::Real) = (trace.varinfo = task_local_storage(:varinfo); ) +set_varinfo!(trace::TracedModel, ::Real) = (trace.varinfo = task_local_storage(:varinfo)) struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T} logp::T @@ -111,7 +111,12 @@ function DynamicPPL.tilde_assume!!( end function DynamicPPL.tilde_observe!!( - ::SMCContext, dist::Distribution, val, vn::Union{VarName,Nothing}, template, vi::AbstractVarInfo + ::SMCContext, + dist::Distribution, + val, + vn::Union{VarName,Nothing}, + template, + vi::AbstractVarInfo, ) val, vi = 
DynamicPPL.tilde_observe!!(DefaultContext(), dist, val, vn, template, vi) task_local_storage(:varinfo, vi) @@ -178,7 +183,9 @@ function StatsBase.sample(rng::AbstractRNG, particles::ParticleContainer) return sample(rng, particles.values, weights(particles)) end -function resample!(rng::AbstractRNG, particles::ParticleContainer, weights::StatsBase.Weights) +function resample!( + rng::AbstractRNG, particles::ParticleContainer, weights::StatsBase.Weights +) idx = sample_ancestors(rng, weights.values) @. particles = Particle($split!(rng, particles.values[idx])) end @@ -210,7 +217,7 @@ Base.keys(pc::ReferencedContainer) = LinearIndices(pc.values) Base.iterate(pc::ReferencedContainer) = iterate(pc.particles) -function Base.iterate(pc::ReferencedContainer, i) +function Base.iterate(pc::ReferencedContainer, i) i == length(pc) && return (pc.reference, i + 1) return iterate(pc.particles, i) end @@ -222,7 +229,7 @@ end function Base.collect(pc::ReferencedContainer) particles = Vector{eltype(pc.particles)}(undef, length(pc)) - particles[1:end-1] = @views(pc.particles) + particles[1:(end - 1)] = @views(pc.particles) particles[end] = pc.reference return particles end @@ -231,14 +238,14 @@ Base.getproperty(pc::ReferencedContainer, s::Symbol) = _getproperty(pc, Val(s)) function _getproperty(pc::ReferencedContainer{PT}, ::Val{:values}) where {PT} values = Vector{PT}(undef, length(pc.particles) + 1) - values[1:end-1] = @views(pc.particles.values) + values[1:(end - 1)] = @views(pc.particles.values) values[end] = pc.reference.value return values end -function _getproperty(pc::ReferencedContainer{PT,WT}, ::Val{:log_weights}) where {PT, WT} +function _getproperty(pc::ReferencedContainer{PT,WT}, ::Val{:log_weights}) where {PT,WT} log_weights = Vector{WT}(undef, length(pc.particles) + 1) - log_weights[1:end-1] = @views(pc.particles.log_weights) + log_weights[1:(end - 1)] = @views(pc.particles.log_weights) log_weights[end] = pc.reference.logw return log_weights end @@ -289,7 +296,7 @@ end 
function initialize( rng::AbstractRNG, model::DynamicPPL.Model, sampler::SMC, N::Integer, ref ) - particles, is_done = if isnothing(ref) + return particles, is_done = if isnothing(ref) initialize(rng, model, sampler, N) else particles, is_done = initialize(rng, model, sampler, N - 1) @@ -324,11 +331,7 @@ function reweight!(particles, ::AbstractMCMC.MCMCThreads) return all(num_done) end -function maybe_resample!( - rng::AbstractRNG, - particles, - resampler::AbstractResampler -) +function maybe_resample!(rng::AbstractRNG, particles, resampler::AbstractResampler) weights = StatsBase.weights(particles) rs_flag = should_resample(weights, resampler) rs_flag && resample!(rng, particles, weights) @@ -343,7 +346,7 @@ function rejuvenate!( ::AbstractMCMC.AbstractMCMCEnsemble, ::Bool, ::Integer; - kwargs... + kwargs..., ) return particles end @@ -354,7 +357,7 @@ function smcsample( sampler::SMC, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer; - ref=nothing + ref=nothing, ) particles, is_done = initialize(rng, model, sampler, N, ref) iter = 0 @@ -374,7 +377,7 @@ function AbstractMCMC.sample( sampler::SMC, N::Integer; ensemble::AbstractMCMC.AbstractMCMCEnsemble=MCMCSerial(), - kwargs... + kwargs..., ) return smcsample(rng, model, sampler, ensemble, N; kwargs...) end @@ -405,12 +408,9 @@ end const PG{T} = ParticleGibbs{T} function AbstractMCMC.step( - rng::AbstractRNG, - model::DynamicPPL.Model, - sampler::ParticleGibbs; - kwargs..., + rng::AbstractRNG, model::DynamicPPL.Model, sampler::ParticleGibbs; kwargs... 
) - particles = smcsample(rng, model, sampler.kernel, MCMCSerial(), sampler.N); + particles = smcsample(rng, model, sampler.kernel, MCMCSerial(), sampler.N) state = sample(rng, particles) return get_varinfo(state), state end From f1ec36c68d6bd697206bb84acb4ad03637f8fe13 Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Wed, 18 Mar 2026 09:52:31 -0400 Subject: [PATCH 04/10] add forgotten import I fixed this locally, but forgot to commit to the PR Co-authored-by: Penelope Yong --- src/mcmc/smc.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 1aaf857ff..996637658 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -1,6 +1,7 @@ #### #### Combining DynamicPPL and Libtask. #### +using StatsFuns: softmax mutable struct TracedModel{T<:TapedTask} const task::T From 753c0ad3236d6c4629bca9d98afb8a3db5a395af Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:08:47 -0400 Subject: [PATCH 05/10] more imports necessary Co-authored-by: Penelope Yong --- src/mcmc/smc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 996637658..9afdbff80 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -96,9 +96,9 @@ struct SMCContext <: DynamicPPL.AbstractContext end function init_context(rng::AbstractRNG, vi::VarInfo, vn::VarName) if ~haskey(vi, vn) - return InitContext(rng, InitFromPrior(), vi.transform_strategy) + return DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior(), vi.transform_strategy) else - return DefaultContext() + return DynamicPPL.DefaultContext() end end From f1b7826ff82118cffddf31f651f90995f4ea052f Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:11:16 -0400 Subject: [PATCH 06/10] formatter Co-authored-by: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com> --- src/mcmc/smc.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 9afdbff80..9681cd826 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -96,7 +96,9 @@ struct SMCContext <: DynamicPPL.AbstractContext end function init_context(rng::AbstractRNG, vi::VarInfo, vn::VarName) if ~haskey(vi, vn) - return DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior(), vi.transform_strategy) + return DynamicPPL.InitContext( + rng, DynamicPPL.InitFromPrior(), vi.transform_strategy + ) else return DynamicPPL.DefaultContext() end From beec574a59c177a11b1254e40d298418d921c30e Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 18 Mar 2026 16:56:03 -0400 Subject: [PATCH 07/10] properly treat `bundle_samples` --- src/mcmc/smc.jl | 61 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 9681cd826..9dd22ba49 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -1,7 +1,7 @@ #### #### Combining DynamicPPL and Libtask. #### -using StatsFuns: softmax +using StatsFuns: softmax, logsumexp mutable struct TracedModel{T<:TapedTask} const task::T @@ -193,6 +193,10 @@ function resample!( @. particles = Particle($split!(rng, particles.values[idx])) end +function logevidence(particles::ParticleContainer) + return logsumexp(particles.log_weights) - log(length(particles)) +end + # TODO: optimize this for the love of god function split!(rng::AbstractRNG, particles::Vector{<:TracedModel}) children = deepcopy.(particles) @@ -269,6 +273,10 @@ function resample!(rng::AbstractRNG, ref::ReferencedContainer, weights::StatsBas return ref end +function logevidence(particles::ReferencedContainer) + return logsumexp(particles.log_weights) - log(length(particles)) +end + #### #### Generic Sequential Monte Carlo sampler. 
#### @@ -341,6 +349,19 @@ function maybe_resample!(rng::AbstractRNG, particles, resampler::AbstractResampl return rs_flag end +# for referenced particle sets, rejuvenate all but the reference trajectory +function rejuvenate!( + rng::AbstractRNG, + ref::ReferencedContainer, + kernel, + parallel::AbstractMCMC.AbstractMCMCEnsemble, + rs_flag::Bool, + iter::Integer; + kwargs..., +) + return rejuvenate!(rng, ref.particles, kernel, parallel, rs_flag, iter) +end + # leave out rejuvenation for now, we'll cross that bridge when we get there function rejuvenate!( ::AbstractRNG, @@ -354,6 +375,7 @@ function rejuvenate!( return particles end +# TODO: add custom logging like is done for AbstractMCMC function smcsample( rng::AbstractRNG, model::DynamicPPL.Model, @@ -361,6 +383,7 @@ function smcsample( ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer; ref=nothing, + kwargs... ) particles, is_done = initialize(rng, model, sampler, N, ref) iter = 0 @@ -373,16 +396,24 @@ function smcsample( return particles end -# this doesn't return an MCMCChain which kinda blows, but whatever function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, sampler::SMC, N::Integer; ensemble::AbstractMCMC.AbstractMCMCEnsemble=MCMCSerial(), + chain_type::Any=DEFAULT_CHAIN_TYPE, kwargs..., ) - return smcsample(rng, model, sampler, ensemble, N; kwargs...) + # perform a particle sweep + particles = smcsample(rng, model, sampler, ensemble, N; kwargs...) 
+ stats = (; logevidence=logevidence(particles)) + + # convert to readable format and bundle samples + sample = map( + x -> DynamicPPL.ParamsWithStats(get_varinfo(x.value), model, stats), particles + ) + return AbstractMCMC.bundle_samples(sample, model, sampler, particles, chain_type) end #### @@ -408,17 +439,18 @@ struct ParticleGibbs{T<:SMC} <: AbstractMCMC.AbstractSampler N::Int end -const PG{T} = ParticleGibbs{T} +const PG = ParticleGibbs function AbstractMCMC.step( rng::AbstractRNG, model::DynamicPPL.Model, sampler::ParticleGibbs; kwargs... ) - particles = smcsample(rng, model, sampler.kernel, MCMCSerial(), sampler.N) - state = sample(rng, particles) - return get_varinfo(state), state + particles = smcsample(rng, model, sampler.kernel, AbstractMCMC.MCMCSerial(), sampler.N) + state = StatsBase.sample(rng, particles) + stats = (; logevidence=logevidence(particles)) + sample = DynamicPPL.ParamsWithStats(get_varinfo(state), model, stats) + return sample, state end -# NOTE: this needs some TLC, not sure how I can better integrate function AbstractMCMC.step( rng::AbstractRNG, model::DynamicPPL.Model, @@ -427,8 +459,15 @@ function AbstractMCMC.step( kwargs..., ) particles = smcsample( - rng, model, sampler.kernel, MCMCSerial(), sampler.N; ref=deepcopy(state) + rng, + model, + sampler.kernel, + AbstractMCMC.MCMCSerial(), + sampler.N; + ref=deepcopy(state), ) - state = sample(rng, particles) - return get_varinfo(state), state + state = StatsBase.sample(rng, particles) + stats = (; logevidence=logevidence(particles)) + sample = DynamicPPL.ParamsWithStats(get_varinfo(state), model, stats) + return sample, state end From f7c2c3e1adf8e9b08e9ee2d5ff0ef610a9d7b62d Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 18 Mar 2026 16:56:18 -0400 Subject: [PATCH 08/10] update some of the original unit tests --- test/mcmc/particle_mcmc.jl | 48 ++++++++++---------------------------- 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/test/mcmc/particle_mcmc.jl 
b/test/mcmc/particle_mcmc.jl index 8b3c68ff1..5329d40e5 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -12,16 +12,10 @@ using Turing @testset "SMC" begin @testset "constructor" begin s = SMC() - @test s.resampler == ResampleWithESSThreshold() + @test s.resampler == AlwaysResample() s = SMC(0.6) - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - - s = SMC(resample_multinomial, 0.6) - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - - s = SMC(resample_systematic) - @test s.resampler === resample_systematic + @test s.resampler === ESSResampler(0.6) end @testset "models" begin @@ -119,26 +113,8 @@ using Turing end @testset "PG" begin - @testset "constructor" begin - s = PG(10) - @test s.nparticles == 10 - @test s.resampler == ResampleWithESSThreshold() - - s = PG(60, 0.6) - @test s.nparticles == 60 - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - - s = PG(80, resample_multinomial, 0.6) - @test s.nparticles == 80 - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - - s = PG(100, resample_systematic) - @test s.nparticles == 100 - @test s.resampler === resample_systematic - end - @testset "chain log-density metadata" begin - test_chain_logp_metadata(PG(10)) + test_chain_logp_metadata(PG(SMC(), 10)) end @testset "logevidence" begin @@ -152,7 +128,7 @@ end return x end - chains_pg = sample(StableRNG(468), test(), PG(10), 100) + chains_pg = sample(StableRNG(468), test(), PG(SMC(), 10), 100) @test all(isone, chains_pg[:x]) pg_logevidence = mean(chains_pg[:logevidence]) @@ -163,7 +139,7 @@ end # https://github.com/TuringLang/Turing.jl/issues/1598 @testset "reference particle" begin - c = sample(gdemo_default, PG(1), 1_000) + c = sample(gdemo_default, PG(SMC(), 1), 1_000) @test length(unique(c[:m])) == 1 @test length(unique(c[:s])) == 1 end @@ -181,7 +157,7 @@ end @addlogprob! 
0.0 end end - c = sample(StableRNG(468), addlogprob_demo(), PG(10), 100) + c = sample(StableRNG(468), addlogprob_demo(), PG(SMC(), 10), 100) # Result should be biased towards x > 0. @test mean(c[:x]) > 0.7 end @@ -192,11 +168,11 @@ end return y ~ Normal(x) end - chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000) + chain = sample(StableRNG(468), kwarg_demo(5.0), PG(SMC(), 20), 1000) @test chain isa MCMCChains.Chains @test mean(chain[:x]) ≈ 2.5 atol = 0.3 - chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000) + chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(SMC(), 20), 1000) @test chain2 isa MCMCChains.Chains @test mean(chain2[:x]) ≈ 7.5 atol = 0.3 end @@ -214,7 +190,7 @@ end return a ~ to_submodel(inner(y, x)) end m1 = nested(1.0) - chn = sample(StableRNG(468), m1, PG(10), 1000) + chn = sample(StableRNG(468), m1, PG(SMC(), 10), 1000) @test mean(chn[:x]) ≈ 0.5 atol = 0.1 end @@ -228,14 +204,14 @@ end return a ~ to_submodel(inner_kwarg(5.0)) end m1 = outer_kwarg1() - chn1 = sample(StableRNG(468), m1, PG(10), 1000) + chn1 = sample(StableRNG(468), m1, PG(SMC(), 10), 1000) @test mean(chn1[Symbol("a.x")]) ≈ 2.5 atol = 0.3 @model function outer_kwarg2(n) return a ~ to_submodel(inner_kwarg(5.0; n=n)) end m2 = outer_kwarg2(10.0) - chn2 = sample(StableRNG(468), m2, PG(10), 1000) + chn2 = sample(StableRNG(468), m2, PG(SMC(), 10), 1000) @test mean(chn2[Symbol("a.x")]) ≈ 7.5 atol = 0.3 end @@ -249,7 +225,7 @@ end end end model = setthreadsafe(f(randn(10)), true) - @test_throws ArgumentError sample(model, PG(10), 100) + @test_throws ArgumentError sample(model, PG(SMC(), 10), 100) end end From 00516aba76da8771d72dadb5dd84d983fd4744c7 Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:58:15 -0400 Subject: [PATCH 09/10] formatter I stg I need to enable auto-format upon saving Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- 
src/mcmc/smc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index 9dd22ba49..e5d804f48 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -383,7 +383,7 @@ function smcsample( ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer; ref=nothing, - kwargs... + kwargs..., ) particles, is_done = initialize(rng, model, sampler, N, ref) iter = 0 From 80bfac8f904d4a1eebe1f9532baf6f4c18c16459 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 18 Mar 2026 17:12:43 -0400 Subject: [PATCH 10/10] convenience constructor --- src/mcmc/smc.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mcmc/smc.jl b/src/mcmc/smc.jl index e5d804f48..f49ccad8c 100644 --- a/src/mcmc/smc.jl +++ b/src/mcmc/smc.jl @@ -441,6 +441,10 @@ end const PG = ParticleGibbs +# for parity with the original interface +PG(N::Int, threshold::Real) = PG(SMC(threshold), N) +PG(N::Int) = PG(N, 0.5) + function AbstractMCMC.step( rng::AbstractRNG, model::DynamicPPL.Model, sampler::ParticleGibbs; kwargs... )