From a3dd5b2b6a70a40489a7bd24eb1b9390acbcaf88 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Thu, 12 Mar 2026 08:08:03 -0400 Subject: [PATCH 1/8] Implements interoperability with sampling libraries. Creates interface.jl. Minor tweaks to tests/erro handling --- Project.toml | 6 +- src/DEER/DEER.jl | 25 +++++--- src/MALA/MALA.jl | 8 +-- src/ParallelMCMC.jl | 13 ++-- src/interface.jl | 124 ++++++++++++++++++++++++++++++++++++++ test/test-Deer-vs-MALA.jl | 2 +- 6 files changed, 154 insertions(+), 24 deletions(-) create mode 100644 src/interface.jl diff --git a/Project.toml b/Project.toml index dd88cdc..2b3df4a 100644 --- a/Project.toml +++ b/Project.toml @@ -4,22 +4,20 @@ version = "0.1.0" authors = ["Ryan Senne "] [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +ADTypes = "1.21.0" AbstractMCMC = "5.10.0" DifferentiationInterface = "0.7.13" -Distributions = "0.25.122" LinearAlgebra = "1.12.0" -LogExpFunctions = "0.3.29" MCMCChains = "7.7.0" Mooncake = "0.4.192" Random = "1.11.0" diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index bc31819..75d4057 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -3,11 +3,12 @@ module DEER using Base.Threads: @threads, threadid using LinearAlgebra using DifferentiationInterface +using ADTypes: ADTypes, AbstractADType import Mooncake: Mooncake using Random const DI = DifferentiationInterface -backend = DI.AutoMooncake(; config=nothing) +const DEFAULT_BACKEND = DI.AutoMooncake(; 
config=nothing) """ Deterministic recursion driven by a pre-generated tape. @@ -16,23 +17,28 @@ Deterministic recursion driven by a pre-generated tape. - `step_lin(x, tape_t, c...)`: surrogate used only for Jacobians. - `consts(x, tape_t) -> Tuple`: returns constants `c...` passed into `step_lin` as DI.Constant. - `const_example`: example tuple of constants, used in `prepare`. +- `backend`: AD backend (any `ADTypes.AbstractADType`); defaults to `AutoMooncake`. """ -struct TapedRecursion{Ff,Fl,Fc,Tt,Ce} +struct TapedRecursion{Ff,Fl,Fc,Tt,Ce,AD<:AbstractADType} step_fwd::Ff step_lin::Fl consts::Fc tape::Vector{Tt} const_example::Ce + backend::AD end "Backward-compatible constructor: uses the same step for forward + Jacobian, and no extra constants." -TapedRecursion(step, tape::Vector) = TapedRecursion(step, step, (_x, _tt)->(), tape, ()) +function TapedRecursion(step, tape::Vector) + return TapedRecursion(step, step, (_x, _tt) -> (), tape, (), DEFAULT_BACKEND) +end "Main constructor." function TapedRecursion( - step_fwd, step_lin, tape::Vector; consts=(_x, _tt)->(), const_example=() + step_fwd, step_lin, tape::Vector; + consts=(_x, _tt) -> (), const_example=(), backend::AbstractADType=DEFAULT_BACKEND, ) - return TapedRecursion(step_fwd, step_lin, consts, tape, const_example) + return TapedRecursion(step_fwd, step_lin, consts, tape, const_example, backend) end # Stable callable for DI (Jacobian always taken w.r.t. step_lin) @@ -46,7 +52,7 @@ function prepare(rec::TapedRecursion, x0::AbstractVector) f = StepWithTape(rec) cs = rec.const_example return DI.prepare_jacobian( - f, backend, x0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... + f, rec.backend, x0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... 
) end @@ -54,10 +60,9 @@ end function prepare_pushforward(rec::TapedRecursion, x0::AbstractVector) f = StepWithTape(rec) cs = rec.const_example - # tx is a tuple of tangents; we supply an example tangent tx0 = (zero(x0),) return DI.prepare_pushforward( - f, backend, x0, tx0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... + f, rec.backend, x0, tx0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... ) end @@ -66,7 +71,7 @@ function jac_full(rec::TapedRecursion, prep, x::AbstractVector, t::Int) f = StepWithTape(rec) cs = rec.consts(x, rec.tape[t]) return DI.jacobian( - f, prep, backend, x, DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)... + f, prep, rec.backend, x, DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)... ) end @@ -94,7 +99,7 @@ function _jvp_step_lin( res = DI.pushforward( f, prep_pf, - backend, + rec.backend, x, tx, DI.Constant(rec.tape[t]), diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 81d5141..72615e6 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -35,8 +35,8 @@ Returns: function mala_step_taped( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real ) - @assert length(x) == length(ξ) - @assert 0.0 < u < 1.0 + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) + 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) g_x = gradlogp(x) # Proposal @@ -77,7 +77,7 @@ function run_mala_sequential_taped( us::AbstractVector, ) T = length(us) - @assert length(ξs) == T + length(ξs) == T || throw(DimensionMismatch("ξs and us must have the same length")) xs = Vector{typeof(x0)}(undef, T + 1) xs[1] = copy(x0) x = copy(x0) @@ -90,7 +90,7 @@ end "Compute the MALA proposal map y = x + ϵ∇logp(x) + sqrt(2ϵ) ξ." 
function mala_proposal(logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector) - @assert length(x) == length(ξ) + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) g_x = gradlogp(x) return x .+ ϵ .* g_x .+ sqrt(2ϵ) .* ξ end diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index 37900e5..7097a64 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -1,13 +1,16 @@ module ParallelMCMC -# imports -using Distributions -using LinearAlgebra -using LogExpFunctions +using AbstractMCMC using MCMCChains +using LinearAlgebra +using Random +using Statistics -# inclusions include("MALA/MALA.jl") include("DEER/DEER.jl") +include("interface.jl") + +export DensityModel, MALASampler, MALATransition, MALAState +export MALA, DEER end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..0d87985 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,124 @@ +#= +AbstractMCMC interface for ParallelMCMC samplers. + +Defines model/sampler/state/transition types and implements +`AbstractMCMC.step` so that `sample(model, sampler, N)` works out of the box. +=# + +""" + DensityModel(logdensity, grad_logdensity, dim) + +Wraps a log-density function and its gradient for use with ParallelMCMC samplers. + +- `logdensity(x::AbstractVector) -> Real` +- `grad_logdensity(x::AbstractVector) -> AbstractVector` +- `dim::Int` — dimensionality of the parameter space +""" +struct DensityModel{F,G} <: AbstractMCMC.AbstractModel + logdensity::F + grad_logdensity::G + dim::Int +end + +""" + MALASampler(epsilon) + +Metropolis-Adjusted Langevin Algorithm sampler with step size `epsilon`. 
+""" +struct MALASampler <: AbstractMCMC.AbstractSampler + epsilon::Float64 +end + +struct MALAState{V<:AbstractVector} + x::V + logp::Float64 +end + +struct MALATransition{V<:AbstractVector} + x::V + logp::Float64 + accepted::Bool +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::MALASampler; + initial_params=nothing, + kwargs..., +) + x = if initial_params !== nothing + copy(initial_params) + else + randn(rng, model.dim) + end + logp_val = model.logdensity(x) + t = MALATransition(x, logp_val, true) + s = MALAState(x, logp_val) + return t, s +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::MALASampler, + state::MALAState; + kwargs..., +) + x = state.x + ϵ = sampler.epsilon + D = model.dim + + ξ = randn(rng, D) + u = rand(rng) + + accepted = MALA.mala_accept_indicator( + model.logdensity, model.grad_logdensity, x, ϵ, ξ, u + ) == 1.0 + + x_next = MALA.mala_step_taped( + model.logdensity, model.grad_logdensity, x, ϵ, ξ, u + ) + + logp_val = accepted ? model.logdensity(x_next) : state.logp + t = MALATransition(x_next, logp_val, accepted) + s = MALAState(x_next, logp_val) + return t, s +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:MALATransition}, + model::DensityModel, + sampler::MALASampler, + state::MALAState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp, :accepted] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 2) + + for i in 1:N + s = samples[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 
1.0 : 0.0 + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end diff --git a/test/test-Deer-vs-MALA.jl b/test/test-Deer-vs-MALA.jl index e52d0ba..b0567d1 100644 --- a/test/test-Deer-vs-MALA.jl +++ b/test/test-Deer-vs-MALA.jl @@ -173,7 +173,7 @@ end s_prev = copy(s0) for t in 1:T s_prev = view(A, :, t) .* s_prev .+ view(B, :, t) - @inbounds S_ref[:, t] .= s_prev + S_ref[:, t] .= s_prev end # Scan From d84233a517d6f4eacadfdc5ff0aa05f0e1063aaa Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Thu, 12 Mar 2026 11:04:00 -0400 Subject: [PATCH 2/8] Adds preconditioning support for MALA and updates API a bit --- src/MALA/MALA.jl | 111 ++++++++++++---- src/interface.jl | 6 +- test/Project.toml | 2 + test/test-AbstractMCMC-Interface.jl | 197 ++++++++++++++++++++++++++++ 4 files changed, 283 insertions(+), 33 deletions(-) create mode 100644 test/test-AbstractMCMC-Interface.jl diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 72615e6..923f628 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -2,20 +2,38 @@ module MALA using Random, LinearAlgebra +# Preconditioner dispatch helpers. cholM is either nothing (identity) or a Cholesky factor. +_apply_M(g, ::Nothing) = g +_apply_M(g, cholM::Cholesky) = cholM.L * (cholM.L' * g) + +_apply_L(ξ, ::Nothing) = ξ +_apply_L(ξ, cholM::Cholesky) = cholM.L * ξ + +_quad_Minv(r, ::Nothing) = dot(r, r) +function _quad_Minv(r, cholM::Cholesky) + w = cholM.L \ r + return dot(w, w) +end + +_logdet_M(::Nothing) = 0.0 +_logdet_M(cholM::Cholesky) = logdet(cholM) + """ Compute the log of the MALA proposal density q(y | x). -We use the Gaussian: - y ~ Normal(x + ϵ∇logp(x), 2ϵ I) + +With mass matrix M (passed as `cholM = cholesky(M)`): + y ~ Normal(x + ϵ M ∇logp(x), 2ϵ M) + +With `cholM=nothing` (default), uses identity M = I. 
""" function logq_mala( - y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real + y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real; + cholM=nothing, ) - μ = x .+ ϵ .* gradlogp_x - # log N(y; μ, 2ϵ I) up to constant: - # -0.5 * ||y-μ||^2 / (2ϵ) - (d/2) log(4πϵ) + μ = x .+ ϵ .* _apply_M(gradlogp_x, cholM) d = length(x) r = y .- μ - return -0.5 * dot(r, r) / (2ϵ) - (d / 2) * log(4π * ϵ) + return -0.5 * _quad_Minv(r, cholM) / (2ϵ) - (d / 2) * log(4π * ϵ) - 0.5 * _logdet_M(cholM) end """ @@ -28,32 +46,30 @@ Inputs: - ϵ: step size - ξ: N(0, I) noise vector (tape) - u: Uniform(0,1) scalar (tape) +- cholM: optional Cholesky factor of mass matrix M (default: identity) Returns: - x_next """ function mala_step_taped( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, ) length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) g_x = gradlogp(x) - # Proposal - y = x .+ ϵ .* g_x .+ sqrt(2ϵ) .* ξ + y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) - # Compute log acceptance ratio: - # log α = logp(y) + log q(x|y) - logp(x) - log q(y|x) logp_x = logp(x) logp_y = logp(y) g_y = gradlogp(y) - logq_y_given_x = logq_mala(y, x, g_x, ϵ) - logq_x_given_y = logq_mala(x, y, g_y, ϵ) + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) - # Accept/reject using tape u return (log(u) < logα) ? 
y : x end @@ -74,7 +90,8 @@ function run_mala_sequential_taped( x0::AbstractVector, ϵ::Real, ξs::Vector{<:AbstractVector}, - us::AbstractVector, + us::AbstractVector; + cholM=nothing, ) T = length(us) length(ξs) == T || throw(DimensionMismatch("ξs and us must have the same length")) @@ -82,27 +99,31 @@ function run_mala_sequential_taped( xs[1] = copy(x0) x = copy(x0) for t in 1:T - x = mala_step_taped(logp, gradlogp, x, ϵ, ξs[t], us[t]) + x = mala_step_taped(logp, gradlogp, x, ϵ, ξs[t], us[t]; cholM=cholM) xs[t + 1] = copy(x) end return xs end -"Compute the MALA proposal map y = x + ϵ∇logp(x) + sqrt(2ϵ) ξ." -function mala_proposal(logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector) +"Compute the MALA proposal map y = x + ϵ M ∇logp(x) + √(2ϵ) L ξ." +function mala_proposal( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector; cholM=nothing, +) length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) g_x = gradlogp(x) - return x .+ ϵ .* g_x .+ sqrt(2ϵ) .* ξ + return x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) end "Compute log acceptance ratio logα(x→y) for MALA." -function mala_logα(logp, gradlogp, x::AbstractVector, y::AbstractVector, ϵ::Real) +function mala_logα( + logp, gradlogp, x::AbstractVector, y::AbstractVector, ϵ::Real; cholM=nothing, +) g_x = gradlogp(x) g_y = gradlogp(y) logp_x = logp(x) logp_y = logp(y) - logq_y_given_x = logq_mala(y, x, g_x, ϵ) - logq_x_given_y = logq_mala(x, y, g_y, ϵ) + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) return (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) end @@ -111,21 +132,55 @@ Primal accept indicator for a taped MALA step. Returns Float64 in {0.0, 1.0} so it can be used as a constant gate. 
""" function mala_accept_indicator( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, ) - y = mala_proposal(logp, gradlogp, x, ϵ, ξ) - logα = mala_logα(logp, gradlogp, x, y, ϵ) + y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) + logα = mala_logα(logp, gradlogp, x, y, ϵ; cholM=cholM) return (log(u) < logα) ? 1.0 : 0.0 end +""" +One taped MALA step, returning both the next state and the accept flag. + +This is the efficient entry point: it evaluates `gradlogp` exactly twice (once at `x`, +once at the proposal `y`), compared to calling `mala_accept_indicator` + `mala_step_taped` +separately which evaluates it five times. + +Returns `(x_next, accepted::Bool)`. +""" +function mala_step_full( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, +) + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) + 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) + + g_x = gradlogp(x) + y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) + + logp_x = logp(x) + logp_y = logp(y) + g_y = gradlogp(y) + + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) + logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) + + accepted = log(u) < logα + x_next = accepted ? y : x + return x_next, accepted +end + """ Stop-gradient surrogate step used for Jacobians. `a` (0.0 or 1.0) must be provided as a constant by the DEER machinery. 
""" function mala_step_surrogate( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, a::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, a::Real; + cholM=nothing, ) - y = mala_proposal(logp, gradlogp, x, ϵ, ξ) + y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) return (a .* y) .+ ((1 - a) .* x) end diff --git a/src/interface.jl b/src/interface.jl index 0d87985..ecb87eb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -72,11 +72,7 @@ function AbstractMCMC.step( ξ = randn(rng, D) u = rand(rng) - accepted = MALA.mala_accept_indicator( - model.logdensity, model.grad_logdensity, x, ϵ, ξ, u - ) == 1.0 - - x_next = MALA.mala_step_taped( + x_next, accepted = MALA.mala_step_full( model.logdensity, model.grad_logdensity, x, ϵ, ξ, u ) diff --git a/test/Project.toml b/test/Project.toml index 4e216a9..693918c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test-AbstractMCMC-Interface.jl b/test/test-AbstractMCMC-Interface.jl new file mode 100644 index 0000000..d55ab7a --- /dev/null +++ b/test/test-AbstractMCMC-Interface.jl @@ -0,0 +1,197 @@ +using Test +using Random +using LinearAlgebra +using MCMCChains + +using ParallelMCMC + +logp_iface(x) = -0.5 * dot(x, x) +gradlogp_iface(x) = -x + +@testset "AbstractMCMC interface" begin + @testset "DensityModel construction" begin + m = DensityModel(logp_iface, gradlogp_iface, 3) + @test m isa ParallelMCMC.AbstractMCMC.AbstractModel + @test m.dim == 3 + @test m.logdensity([0.0, 0.0, 0.0]) == 0.0 + @test 
m.grad_logdensity([1.0, 2.0, 3.0]) == [-1.0, -2.0, -3.0] + end + + @testset "MALASampler construction" begin + s = MALASampler(0.1) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon == 0.1 + end + + @testset "initial step draws from rng" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 5) + sampler = MALASampler(0.1) + + transition, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test transition isa MALATransition + @test state isa MALAState + @test length(transition.x) == 5 + @test length(state.x) == 5 + @test transition.logp == logp_iface(transition.x) + @test state.logp == transition.logp + @test transition.accepted == true + end + + @testset "initial step respects initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + x0 = [1.0, 2.0, 3.0] + + transition, state = ParallelMCMC.AbstractMCMC.step( + rng, model, sampler; initial_params=x0 + ) + + @test transition.x == x0 + @test state.x == x0 + @test transition.logp == logp_iface(x0) + end + + @testset "initial step does not mutate initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + x0 = [1.0, 2.0, 3.0] + x0_copy = copy(x0) + + ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + @test x0 == x0_copy + end + + @testset "subsequent step produces valid transition" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 4) + sampler = MALASampler(0.1) + + t1, s1 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + t2, s2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, s1) + + @test t2 isa MALATransition + @test s2 isa MALAState + @test length(t2.x) == 4 + @test t2.accepted isa Bool + @test isfinite(t2.logp) + @test s2.x == t2.x + @test s2.logp == t2.logp + end + + @testset "step determinism with fixed rng" begin + model = 
DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.2) + + rng1 = MersenneTwister(999) + t1a, s1a = ParallelMCMC.AbstractMCMC.step(rng1, model, sampler) + t1b, s1b = ParallelMCMC.AbstractMCMC.step(rng1, model, sampler, s1a) + + rng2 = MersenneTwister(999) + t2a, s2a = ParallelMCMC.AbstractMCMC.step(rng2, model, sampler) + t2b, s2b = ParallelMCMC.AbstractMCMC.step(rng2, model, sampler, s2a) + + @test t1a.x == t2a.x + @test t1b.x == t2b.x + @test t1b.accepted == t2b.accepted + end + + @testset "rejection preserves logp from state" begin + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.5) + + # Run enough steps to hit a rejection + rng = MersenneTwister(12345) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + found_rejection = false + for _ in 1:500 + t, state_new = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + if !t.accepted + # On rejection, x doesn't move and logp should match prior state + @test t.x == state.x + @test t.logp == state.logp + found_rejection = true + break + end + state = state_new + end + @test found_rejection + end + + @testset "sample() runs end-to-end" begin + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + + chain = sample(model, sampler, 100; progress=false) + @test length(chain) == 100 + end + + @testset "sample() with chain_type=Chains" begin + model = DensityModel(logp_iface, gradlogp_iface, 2) + sampler = MALASampler(0.15) + + chain = sample( + model, sampler, 200; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 200 + + # Parameter columns present + param_names = names(chain, :parameters) + @test length(param_names) == 2 + + # Internal columns present + internal_names = names(chain, :internals) + @test :logp in internal_names + @test :accepted in internal_names + + # logp values should be finite + @test all(isfinite, chain[:logp]) + + # accepted values 
should be 0 or 1 + acc = chain[:accepted] + @test all(a -> a == 0.0 || a == 1.0, acc) + end + + @testset "sample() with custom param_names" begin + model = DensityModel(logp_iface, gradlogp_iface, 2) + sampler = MALASampler(0.15) + + chain = sample( + model, sampler, 50; + chain_type=MCMCChains.Chains, progress=false, + param_names=[:mu, :sigma], + ) + + @test chain isa MCMCChains.Chains + param_names = names(chain, :parameters) + @test :mu in param_names + @test :sigma in param_names + end + + @testset "stationary distribution via sample()" begin + D = 3 + model = DensityModel(logp_iface, gradlogp_iface, D) + sampler = MALASampler(0.3) + + chain = sample( + MersenneTwister(2025), model, sampler, 20_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + burn = 3_000 + post = Array(chain[burn:end, :, :]) # (N-burn) × D + + mu = vec(mean(post; dims=1)) + @test maximum(abs.(mu)) < 0.1 + + vars = vec(var(post; dims=1)) + @test maximum(abs.(vars .- 1.0)) < 0.15 + end +end From 1387e338127c8ae80c9ea1115f50d9bbd050ba7a Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Tue, 17 Mar 2026 09:41:42 -0400 Subject: [PATCH 3/8] Add GPU compatible Deer-Mala. Adds adaptive step size. 
Begins interoperability with Turing/MCMCChains/AbstracTMCMC --- Project.toml | 7 + ext/LogDensityProblemsExt.jl | 54 +++ src/DEER/DEER.jl | 6 +- src/MALA/MALA.jl | 119 ++++++- src/ParallelMCMC.jl | 5 +- src/interface.jl | 488 ++++++++++++++++++++++++++++ test/test-AbstractMCMC-Interface.jl | 1 + test/test-Adaptive-MALA.jl | 238 ++++++++++++++ test/test-Batched-MALA.jl | 111 +++++++ test/test-DEER-Interface.jl | 170 ++++++++++ 10 files changed, 1192 insertions(+), 7 deletions(-) create mode 100644 ext/LogDensityProblemsExt.jl create mode 100644 test/test-Adaptive-MALA.jl create mode 100644 test/test-Batched-MALA.jl create mode 100644 test/test-DEER-Interface.jl diff --git a/Project.toml b/Project.toml index 2b3df4a..bba21f2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,15 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" + +[extensions] +LogDensityProblemsExt = "LogDensityProblems" + [compat] ADTypes = "1.21.0" +LogDensityProblems = "2" AbstractMCMC = "5.10.0" DifferentiationInterface = "0.7.13" LinearAlgebra = "1.12.0" diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl new file mode 100644 index 0000000..84effbc --- /dev/null +++ b/ext/LogDensityProblemsExt.jl @@ -0,0 +1,54 @@ +module LogDensityProblemsExt + +using ParallelMCMC +import LogDensityProblems + +""" + DensityModel(ld) + +Construct a `DensityModel` from any object implementing the +[LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface. + +`ld` must support: +- `LogDensityProblems.capabilities(ld)` returning at least + `LogDensityProblems.LogDensityOrder{1}` (i.e. gradient available). 
+- `LogDensityProblems.dimension(ld)` → `Int` +- `LogDensityProblems.logdensity_and_gradient(ld, x)` → `(logp, grad)` + +# Turing.jl / DynamicPPL example +```julia +using Turing, LogDensityProblems, LogDensityProblemsAD, Mooncake, ParallelMCMC, MCMCChains + +@model function mymodel(y) + μ ~ Normal(0, 1) + y ~ Normal(μ, 0.5) +end + +obs = 1.5 +ld = DynamicPPL.LogDensityFunction(mymodel(obs)) +ldg = LogDensityProblemsAD.ADgradient(Mooncake.Extras.MooncakeAD(), ld) + +model = DensityModel(ldg) +chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; + chain_type=MCMCChains.Chains, progress=true) +``` +""" +function ParallelMCMC.DensityModel(ld) + caps = LogDensityProblems.capabilities(ld) + caps isa LogDensityProblems.LogDensityOrder{0} && + error("LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). " * + "Wrap it with LogDensityProblemsAD.ADgradient first.") + + dim = LogDensityProblems.dimension(ld) + + logp(x) = LogDensityProblems.logdensity(ld, x) + + function gradlogp(x) + _, g = LogDensityProblems.logdensity_and_gradient(ld, x) + return g + end + + return ParallelMCMC.DensityModel(logp, gradlogp, dim) +end + +end # module diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index 75d4057..eddcbb6 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -280,11 +280,7 @@ function deer_update( S_new = solve_affine_scan_diag(A, B, s0) if damping != 1.0 - @threads for t in 1:T - for i in 1:D - S_new[i, t] = (1 - damping) * S[i, t] + damping * S_new[i, t] - end - end + @. S_new = (1 - damping) * S + damping * S_new end return S_new diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 923f628..8f1e936 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -152,6 +152,23 @@ Returns `(x_next, accepted::Bool)`. 
function mala_step_full( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing, +) + x_next, accepted, _ = mala_step_with_logα(logp, gradlogp, x, ϵ, ξ, u; cholM=cholM) + return x_next, accepted +end + +""" +One taped MALA step, returning the next state, the accept flag, **and** the raw +log acceptance ratio `logα = log p(y) + log q(x|y) - log p(x) - log q(y|x)`. + +The returned `logα` is the un-clamped value; the actual acceptance probability is +`min(1, exp(logα))`. This is needed by adaptive step-size schemes (dual averaging). + +Returns `(x_next, accepted::Bool, logα::Float64)`. +""" +function mala_step_with_logα( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, ) length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) @@ -169,7 +186,7 @@ function mala_step_full( accepted = log(u) < logα x_next = accepted ? y : x - return x_next, accepted + return x_next, accepted, Float64(logα) end """ @@ -184,4 +201,104 @@ function mala_step_surrogate( return (a .* y) .+ ((1 - a) .* x) end +# Apply mass matrix to a D×N matrix of gradient columns (same math as scalar, +# matrix multiply broadcasts naturally). +_apply_M_batched(G::AbstractMatrix, ::Nothing) = G +_apply_M_batched(G::AbstractMatrix, cholM::Cholesky) = cholM.L * (cholM.L' * G) + +_apply_L_batched(Ξ::AbstractMatrix, ::Nothing) = Ξ +_apply_L_batched(Ξ::AbstractMatrix, cholM::Cholesky) = cholM.L * Ξ + +""" +Compute column-wise M⁻¹-norm squared: `[||R[:,n]||²_{M⁻¹}]_n`. +`R` is D×N; returns a length-N vector. +GPU-compatible (uses `sum(abs2, …; dims=1)` which works on CuArrays). +Note: `cholM` must be `nothing` when using GPU arrays; the triangular solve +`cholM.L \\ R` pulls device arrays to CPU. 
+""" +function _quad_Minv_batched(R::AbstractMatrix, ::Nothing) + return vec(sum(abs2, R; dims=1)) +end + +function _quad_Minv_batched(R::AbstractMatrix, cholM::Cholesky) + W = cholM.L \ R + return vec(sum(abs2, W; dims=1)) +end + +""" + logq_mala_batched(Y, X, gradlogp_X, ε; cholM=nothing) + +Compute `log q(Y[:,n] | X[:,n])` for all N chains simultaneously. +`Y`, `X`, `gradlogp_X` are D×N; returns a length-N vector. +""" +function logq_mala_batched( + Y::AbstractMatrix, + X::AbstractMatrix, + gradlogp_X::AbstractMatrix, + ε::Real; + cholM=nothing, +) + D = size(X, 1) + μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM) + R = Y .- μ + q = _quad_Minv_batched(R, cholM) + ldet = _logdet_M(cholM) + return @. -0.5 * q / (2ε) - (D / 2) * log(4π * ε) - 0.5 * ldet +end + +""" + mala_step_batched(logp_batch, gradlogp_batch, X, ε, Ξ, u; cholM=nothing) + +Run one MALA step for N chains simultaneously. + +- `X` :: D×N — current states (one chain per column). +- `Ξ` :: D×N — N(0,I) noise. +- `u` :: length-N — Uniform(0,1) draws. +- `logp_batch(X)` → length-N log-densities. +- `gradlogp_batch(X)` → D×N gradient matrix. + +Returns `(X_next::AbstractMatrix, accepted::AbstractVector)`. + +**GPU use:** pass `CuArray` inputs and GPU-compatible `logp_batch`/`gradlogp_batch`. +Requires `cholM=nothing` for full on-device execution (Cholesky preconditioner involves +a CPU-side triangular solve). Use `eltype(X)` for `ε` to avoid float-type promotions +that would pull data off GPU. +""" +function mala_step_batched( + logp_batch, + gradlogp_batch, + X::AbstractMatrix, + ε::Real, + Ξ::AbstractMatrix, + u::AbstractVector; + cholM=nothing, +) + D, N = size(X) + size(Ξ) == (D, N) || throw(DimensionMismatch("X and Ξ must have the same size")) + length(u) == N || throw(DimensionMismatch("u must have length N = size(X,2)")) + + # Cast ε to element type of X to avoid float-promotion off GPU. 
+ ε_T = eltype(X)(ε) + + G_X = gradlogp_batch(X) # D×N + Y = X .+ ε_T .* _apply_M_batched(G_X, cholM) .+ + sqrt(2 * ε_T) .* _apply_L_batched(Ξ, cholM) # D×N + + lp_X = logp_batch(X) # N + lp_Y = logp_batch(Y) # N + G_Y = gradlogp_batch(Y) # D×N + + lq_YX = logq_mala_batched(Y, X, G_X, ε_T; cholM=cholM) # N + lq_XY = logq_mala_batched(X, Y, G_Y, ε_T; cholM=cholM) # N + + logα = @. (lp_Y + lq_XY) - (lp_X + lq_YX) # N + accepted = @. log(u) < logα # N Bool + + # Select: proposal if accepted, current if rejected. + # reshape to 1×N so it broadcasts against D×N. + mask = reshape(accepted, 1, N) + X_next = @. ifelse(mask, Y, X) # D×N + return X_next, vec(accepted) +end + end # module diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index 7097a64..da2799c 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -10,7 +10,10 @@ include("MALA/MALA.jl") include("DEER/DEER.jl") include("interface.jl") -export DensityModel, MALASampler, MALATransition, MALAState +export DensityModel, BatchedDensityModel +export MALASampler, MALATransition, MALAState +export AdaptiveMALASampler, AdaptiveMALATransition, AdaptiveMALAState +export DEERSampler, DEERTransition, DEERState, MALATapeElement export MALA, DEER end diff --git a/src/interface.jl b/src/interface.jl index ecb87eb..03ba135 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -118,3 +118,491 @@ function AbstractMCMC.bundle_samples( Dict(:parameters => names, :internals => internal_names), ) end + +""" + MALATapeElement(ξ, u) + +One element of the MALA noise tape: a noise vector `ξ ~ N(0,I)` and a uniform +scalar `u ~ Uniform(0,1)`. Stored with a concrete vector type `V` for type +stability inside `DEER.TapedRecursion`. +""" +struct MALATapeElement{V<:AbstractVector} + ξ::V + u::Float64 +end + +""" + DEERSampler(epsilon; T, maxiter, tol_abs, tol_rel, jacobian, damping, probes) + +DEER-accelerated MALA sampler. 
+ +DEER solves for a trajectory of `T` steps in parallel (O(log T) iterations), +then the AbstractMCMC interface returns samples from that trajectory +sequentially. When the trajectory is exhausted a new tape is drawn and DEER +re-solves starting from the last state. + +# Arguments +- `epsilon` — MALA step size. +- `T` — trajectory length per DEER solve (default 64). +- `maxiter` — maximum DEER iterations per solve (default 200). +- `tol_abs`, `tol_rel` — convergence tolerances (default 1e-6, 1e-5). +- `jacobian` — Jacobian mode: `:diag`, `:stoch_diag`, or `:full` (default `:diag`). +- `damping` — DEER damping in (0,1] (default 0.5; helps convergence). +- `probes` — Hutchinson probes for `:stoch_diag` mode (default 1). + +# Parallel chains +Both `MALASampler` and `DEERSampler` are compatible with +`AbstractMCMC.sample(model, sampler, MCMCThreads(), N, nchains)`. Each chain +has its own immutable state and RNG so there is no shared mutable data. +Note that each `DEERSampler` chain internally uses `Base.Threads.@threads` +for the parallel scan, so running many chains via `MCMCThreads()` on a machine +with few threads may over-subscribe the thread pool. +""" +struct DEERSampler <: AbstractMCMC.AbstractSampler + epsilon::Float64 + T::Int + maxiter::Int + tol_abs::Float64 + tol_rel::Float64 + jacobian::Symbol + damping::Float64 + probes::Int +end + +function DEERSampler( + epsilon::Real; + T::Int=64, + maxiter::Int=200, + tol_abs::Real=1e-6, + tol_rel::Real=1e-5, + jacobian::Symbol=:diag, + damping::Real=0.5, + probes::Int=1, +) + return DEERSampler( + Float64(epsilon), T, maxiter, + Float64(tol_abs), Float64(tol_rel), + jacobian, Float64(damping), probes, + ) +end + +""" +State for a `DEERSampler` chain. + +- `x` — current position (= `trajectory[:, t]`, the last returned sample). +- `logp` — log-density at `x`. +- `trajectory` — D×T matrix produced by the most recent DEER solve. +- `tape` — noise tape used for that solve. 
+- `t` — index within `trajectory` of the last returned sample (1-indexed). +""" +struct DEERState{V<:AbstractVector, M<:AbstractMatrix} + x::V + logp::Float64 + trajectory::M + tape::Vector{MALATapeElement{V}} + t::Int +end + +""" +One DEER sample: parameter vector `x` and its log-density `logp`. +""" +struct DEERTransition{V<:AbstractVector} + x::V + logp::Float64 +end + +function _build_mala_deer_rec( + model::DensityModel, ε::Float64, tape::Vector{<:MALATapeElement} +) + logp = model.logdensity + gradlogp = model.grad_logdensity + + step_fwd = (x, te) -> MALA.mala_step_taped(logp, gradlogp, x, ε, te.ξ, te.u) + step_lin = (x, te, a) -> MALA.mala_step_surrogate(logp, gradlogp, x, ε, te.ξ, a) + consts = (x, te) -> (MALA.mala_accept_indicator(logp, gradlogp, x, ε, te.ξ, te.u),) + + return DEER.TapedRecursion( + step_fwd, step_lin, tape; + consts=consts, const_example=(0.0,), + ) +end + +function _deer_solve_new_tape( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::DEERSampler, + x0::AbstractVector, +) + D = model.dim + T = sampler.T + tape = [MALATapeElement(randn(rng, D), rand(rng)) for _ in 1:T] + rec = _build_mala_deer_rec(model, sampler.epsilon, tape) + S = DEER.solve( + rec, x0; + tol_abs = sampler.tol_abs, + tol_rel = sampler.tol_rel, + maxiter = sampler.maxiter, + jacobian = sampler.jacobian, + damping = sampler.damping, + probes = sampler.probes, + rng = rng, + ) + return S, tape +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::DEERSampler; + initial_params=nothing, + kwargs..., +) + x0 = if initial_params !== nothing + copy(initial_params) + else + randn(rng, model.dim) + end + + S, tape = _deer_solve_new_tape(rng, model, sampler, x0) + + x1 = S[:, 1] + logp1 = Float64(model.logdensity(x1)) + trans = DEERTransition(x1, logp1) + state = DEERState(x1, logp1, S, tape, 1) + return trans, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + 
sampler::DEERSampler, + state::DEERState; + kwargs..., +) + T = sampler.T + t_next = state.t + 1 + + if t_next <= T + # Consume the next cached sample from the trajectory. + x_new = state.trajectory[:, t_next] + logp_new = Float64(model.logdensity(x_new)) + trans = DEERTransition(x_new, logp_new) + new_state = DEERState(x_new, logp_new, state.trajectory, state.tape, t_next) + return trans, new_state + else + # Trajectory exhausted — re-solve with a fresh tape. + x0 = state.trajectory[:, T] + S_new, tape = _deer_solve_new_tape(rng, model, sampler, x0) + x_new = S_new[:, 1] + logp_new = Float64(model.logdensity(x_new)) + trans = DEERTransition(x_new, logp_new) + new_state = DEERState(x_new, logp_new, S_new, tape, 1) + return trans, new_state + end +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:DEERTransition}, + model::DensityModel, + sampler::DEERSampler, + state::DEERState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 1) + + for i in 1:N + vals[i, :] .= samples[i].x + internals[i, 1] = samples[i].logp + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +""" + BatchedDensityModel(logdensity_batch, grad_logdensity_batch, dim) + +Wraps batched log-density and gradient functions for use with +`MALA.mala_step_batched`. + +- `logdensity_batch(X::AbstractMatrix) -> AbstractVector` — + `X` is D×N; returns a length-N vector of log-densities. +- `grad_logdensity_batch(X::AbstractMatrix) -> AbstractMatrix` — + `X` is D×N; returns a D×N gradient matrix (column `n` = gradient for chain `n`). +- `dim::Int` — dimensionality D of the parameter space. 
+ +# GPU use +Pass `CuMatrix` inputs and implement the two functions to return `CuArray`s. +Requires `cholM=nothing` in `mala_step_batched` for fully on-device execution. + +# Example +```julia +logp_b(X) = vec(sum(x -> -0.5x^2, X; dims=1)) # standard normal, N chains +gradlogp_b(X) = -X +bmodel = BatchedDensityModel(logp_b, gradlogp_b, 3) + +X = randn(3, 100) # 100 chains of dimension 3 +Xi = randn(3, 100) +u = rand(100) +X_next, accepted = MALA.mala_step_batched(bmodel.logdensity_batch, + bmodel.grad_logdensity_batch, + X, 0.1, Xi, u) +``` +""" +struct BatchedDensityModel{F,G} <: AbstractMCMC.AbstractModel + logdensity_batch::F + grad_logdensity_batch::G + dim::Int +end + +""" + AdaptiveMALASampler(epsilon_init; n_warmup, target_accept, gamma, t0, kappa, cholM) + +MALA sampler with automatic step-size adaptation via dual averaging +(Nesterov 2009, as used in NUTS — Hoffman & Gelman 2014). + +During the first `n_warmup` steps the step size `ε` is adapted online to drive +the Metropolis acceptance rate toward `target_accept` (default 0.574, which is +the asymptotically optimal rate for MALA in high dimensions). After warmup the +smoothed estimate `ε̄` is frozen and used for all remaining steps. + +# Algorithm +At warmup step `m`, given current log-acceptance ratio `logα`: + + α = min(1, exp(logα)) + H̄_m = (1 − 1/(m+t₀)) H̄_{m−1} + (1/(m+t₀)) (δ − α) + log ε_m = μ − √m/γ · H̄_m (instantaneous) + log ε̄_m = m^(−κ) log ε_m + (1−m^(−κ)) log ε̄_{m−1} (smoothed) + +where μ = log(10 ε₀) is a fixed target. After warmup `ε̄` is used. + +# Keyword arguments +- `n_warmup` — adaptation steps (default 1000). +- `target_accept` — δ, desired acceptance rate (default 0.574). +- `gamma` — γ, regularisation strength (default 0.05). +- `t0` — stability offset (default 10.0). +- `kappa` — shrinkage exponent κ ∈ (0.5, 1] (default 0.75). +- `cholM` — optional Cholesky factor of a mass matrix `M` (default `nothing` = identity). 
+ +# MCMCChains output +The `Chains` object includes internals `[:logp, :accepted, :step_size, :is_warmup]`. +After warmup, `step_size` is constant (the frozen `ε̄`). + +# Parallel chains +Works with `MCMCThreads()`. R-hat and ESS are computed automatically by +`MCMCChains` from multi-chain output. + +# Turing.jl / LogDensityProblems +Load `LogDensityProblems` (and optionally `LogDensityProblemsAD`) then use the +`DensityModel(ld)` constructor to wrap any `LogDensityProblems`-compatible model +(including Turing/DynamicPPL models) directly. +""" +struct AdaptiveMALASampler{CM} <: AbstractMCMC.AbstractSampler + epsilon_init::Float64 + n_warmup::Int + target_accept::Float64 + gamma::Float64 + t0::Float64 + kappa::Float64 + cholM::CM +end + +function AdaptiveMALASampler( + epsilon_init::Real; + n_warmup::Int=1000, + target_accept::Real=0.574, + gamma::Real=0.05, + t0::Real=10.0, + kappa::Real=0.75, + cholM=nothing, +) + return AdaptiveMALASampler( + Float64(epsilon_init), n_warmup, + Float64(target_accept), Float64(gamma), Float64(t0), Float64(kappa), + cholM, + ) +end + +struct AdaptiveMALAState{V<:AbstractVector} + x::V + logp::Float64 + epsilon::Float64 # instantaneous step size ε_m + epsilon_bar::Float64 # smoothed step size ε̄_m (frozen after warmup) + H_bar::Float64 # dual-average statistic H̄_m + step::Int # warmup step counter (0 = initialisation) +end + +struct AdaptiveMALATransition{V<:AbstractVector} + x::V + logp::Float64 + accepted::Bool + step_size::Float64 # ε used for this step + is_warmup::Bool +end + +function _dual_average_update( + epsilon_init::Float64, + epsilon_bar::Float64, + H_bar::Float64, + m::Int, + logα::Float64, + sampler::AdaptiveMALASampler, +) + α = min(1.0, exp(logα)) + δ = sampler.target_accept + γ = sampler.gamma + t0 = sampler.t0 + κ = sampler.kappa + μ = log(10.0 * epsilon_init) # fixed shrinkage target + + H_bar_new = (1.0 - 1.0 / (m + t0)) * H_bar + (1.0 / (m + t0)) * (δ - α) + log_ε = μ - sqrt(Float64(m)) / γ * H_bar_new + 
log_ε_bar_new = Float64(m)^(-κ) * log_ε + (1.0 - Float64(m)^(-κ)) * log(epsilon_bar) + + return exp(log_ε), exp(log_ε_bar_new), H_bar_new +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::AdaptiveMALASampler; + initial_params=nothing, + kwargs..., +) + x = if initial_params !== nothing + copy(initial_params) + else + randn(rng, model.dim) + end + logp_val = Float64(model.logdensity(x)) + trans = AdaptiveMALATransition(x, logp_val, true, sampler.epsilon_init, true) + state = AdaptiveMALAState(x, logp_val, sampler.epsilon_init, sampler.epsilon_init, 0.0, 0) + return trans, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::AdaptiveMALASampler, + state::AdaptiveMALAState; + kwargs..., +) + D = model.dim + in_warmup = state.step < sampler.n_warmup + ε = in_warmup ? state.epsilon : state.epsilon_bar + + ξ = randn(rng, D) + u = rand(rng) + + x_next, accepted, logα = MALA.mala_step_with_logα( + model.logdensity, model.grad_logdensity, state.x, ε, ξ, u; + cholM=sampler.cholM, + ) + + logp_next = accepted ? 
Float64(model.logdensity(x_next)) : state.logp + + # Dual-average adaptation (only during warmup) + m_new = state.step + 1 + ε_new, ε_bar_new, H_bar_new = if in_warmup + _dual_average_update(sampler.epsilon_init, state.epsilon_bar, state.H_bar, m_new, logα, sampler) + else + state.epsilon, state.epsilon_bar, state.H_bar + end + + trans = AdaptiveMALATransition(x_next, logp_next, accepted, ε, in_warmup) + new_state = AdaptiveMALAState(x_next, logp_next, ε_new, ε_bar_new, H_bar_new, m_new) + return trans, new_state +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:AdaptiveMALATransition}, + model::DensityModel, + sampler::AdaptiveMALASampler, + state::AdaptiveMALAState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp, :accepted, :step_size, :is_warmup] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 4) + + for i in 1:N + s = samples[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 1.0 : 0.0 + internals[i, 3] = s.step_size + internals[i, 4] = s.is_warmup ? 1.0 : 0.0 + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +""" + DensityModel(ld) + +Construct a `DensityModel` from any object `ld` that implements the +[LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface, +i.e. provides `LogDensityProblems.logdensity`, `LogDensityProblems.logdensity_and_gradient`, +and `LogDensityProblems.dimension`. 
+ +Requires the `LogDensityProblems` package to be loaded: + +```julia +using LogDensityProblems # activates the extension; `ld` must provide gradients +using LogDensityProblemsAD # optional: adds AD-based gradients via `ADgradient` +``` + +# Turing.jl example +```julia +using Turing, LogDensityProblems, LogDensityProblemsAD, Mooncake, ParallelMCMC, MCMCChains + +@model function mymodel(data) + μ ~ Normal(0, 1) + data ~ Normal(μ, 1) +end + +ld = DynamicPPL.LogDensityFunction(mymodel(1.5)) +ldg = LogDensityProblemsAD.ADgradient(Mooncake.Extras.MooncakeAD(), ld) +model = DensityModel(ldg) + +chain = sample(model, AdaptiveMALASampler(0.1; n_warmup=500), 2000; + chain_type=MCMCChains.Chains, progress=true) +``` + +This method is defined in the `LogDensityProblemsExt` extension and is only +available when `LogDensityProblems` has been loaded. +""" +function DensityModel end # extended by LogDensityProblemsExt diff --git a/test/test-AbstractMCMC-Interface.jl b/test/test-AbstractMCMC-Interface.jl index d55ab7a..3da103d 100644 --- a/test/test-AbstractMCMC-Interface.jl +++ b/test/test-AbstractMCMC-Interface.jl @@ -1,6 +1,7 @@ using Test using Random using LinearAlgebra +using Statistics using MCMCChains using ParallelMCMC diff --git a/test/test-Adaptive-MALA.jl b/test/test-Adaptive-MALA.jl new file mode 100644 index 0000000..15ce7f2 --- /dev/null +++ b/test/test-Adaptive-MALA.jl @@ -0,0 +1,238 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC +const MALA = ParallelMCMC.MALA + +logp_adapt(x) = -0.5 * dot(x, x) +gradlogp_adapt(x) = -x + + +@testset "mala_step_with_logα returns same x_next and accepted as mala_step_full" begin + rng = MersenneTwister(1) + D = 4 + for _ in 1:20 + x = randn(rng, D) + ξ = randn(rng, D) + u = rand(rng) * 0.999 + 1e-15 + + x1, a1 = MALA.mala_step_full(logp_adapt, gradlogp_adapt, x, 0.1, ξ, u) + x2, a2, logα = MALA.mala_step_with_logα(logp_adapt, gradlogp_adapt, x, 0.1, ξ, u) + + @test x1 == x2 + @test a1 == a2 + @test isfinite(logα) + # When accepted
logα ≥ log(u), when rejected logα < log(u) + @test a2 == (log(u) < logα) + end +end + +@testset "AdaptiveMALASampler construction" begin + s = AdaptiveMALASampler(0.1) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon_init == 0.1 + @test s.n_warmup == 1000 + @test s.target_accept ≈ 0.574 + @test s.gamma ≈ 0.05 + @test s.t0 ≈ 10.0 + @test s.kappa ≈ 0.75 + @test s.cholM === nothing + + s2 = AdaptiveMALASampler(0.05; n_warmup=500, target_accept=0.65) + @test s2.n_warmup == 500 + @test s2.target_accept ≈ 0.65 +end + +@testset "AdaptiveMALASampler initial step" begin + rng = MersenneTwister(42) + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + sampler = AdaptiveMALASampler(0.1) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test trans isa AdaptiveMALATransition + @test state isa AdaptiveMALAState + @test length(trans.x) == 3 + @test isfinite(trans.logp) + @test trans.step_size == sampler.epsilon_init + @test trans.is_warmup == true + @test state.step == 0 + @test state.epsilon == sampler.epsilon_init + @test state.epsilon_bar == sampler.epsilon_init +end + +@testset "AdaptiveMALASampler initial step respects initial_params" begin + rng = MersenneTwister(7) + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.1) + x0 = [1.0, -1.0] + x0_copy = copy(x0) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + + @test x0 == x0_copy # not mutated + @test trans.x == x0 +end + +@testset "step counter increments during warmup" begin + rng = MersenneTwister(5) + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.05; n_warmup=10) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + for expected_step in 1:5 + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state.step == expected_step + @test state.step <= sampler.n_warmup # still in warmup + end +end + +@testset "step size 
changes during warmup but freezes after" begin + rng = MersenneTwister(11) + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + n_w = 20 + sampler = AdaptiveMALASampler(0.1; n_warmup=n_w) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + epsilons_warmup = Float64[] + for _ in 1:n_w + t, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + push!(epsilons_warmup, t.step_size) + @test t.is_warmup == true + end + + # At least some epsilon change during warmup + @test !all(==(epsilons_warmup[1]), epsilons_warmup) + + # First post-warmup step: is_warmup should be false, step_size should be frozen ε̄ + epsilon_bar_final = state.epsilon_bar + t_post, state_post = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test t_post.is_warmup == false + @test t_post.step_size ≈ epsilon_bar_final + + # Subsequent post-warmup steps keep the same step_size + _, state_post2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state_post) + @test state_post2.epsilon_bar ≈ epsilon_bar_final +end + +@testset "adapted step size targets acceptance rate" begin + D = 5 + model = DensityModel(logp_adapt, gradlogp_adapt, D) + n_w = 2_000 + sampler = AdaptiveMALASampler(0.5; n_warmup=n_w, target_accept=0.574) + + # Use a single persistent RNG throughout warmup and post-warmup. + rng = MersenneTwister(42) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + for _ in 1:n_w + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + end + + # After warmup, ε̄ should yield acceptance near target. + # Run 500 post-warmup steps and measure actual rate. 
+ n_accept = 0 + for _ in 1:500 + t, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + n_accept += t.accepted + end + accept_rate = n_accept / 500 + + @test 0.35 < accept_rate < 0.80 # wide tolerance; adaptation is approximate +end + +@testset "sample() end-to-end with AdaptiveMALASampler" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.2; n_warmup=50) + + samples = sample(MersenneTwister(1), model, sampler, 100; progress=false) + @test length(samples) == 100 +end + +@testset "sample() with chain_type=Chains" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.2; n_warmup=50) + + chain = sample( + MersenneTwister(1), model, sampler, 150; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 150 + + internals = names(chain, :internals) + @test :logp in internals + @test :accepted in internals + @test :step_size in internals + @test :is_warmup in internals + + @test all(isfinite, chain[:logp]) + @test all(x -> x == 0.0 || x == 1.0, chain[:accepted]) + @test all(s -> s > 0, chain[:step_size]) + + # Warmup samples appear at the start + @test chain[:is_warmup][1] == 1.0 + # Post-warmup samples have is_warmup == 0 + @test chain[:is_warmup][end] == 0.0 +end + +@testset "step_size is constant after warmup" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + n_w = 30 + sampler = AdaptiveMALASampler(0.1; n_warmup=n_w) + + chain = sample( + MersenneTwister(3), model, sampler, n_w + 50; + chain_type=MCMCChains.Chains, progress=false, + ) + + # Filter by is_warmup flag to avoid off-by-one from the init transition. 
+ is_wup = vec(chain[:is_warmup]) .== 1.0 + step_sizes = vec(chain[:step_size]) + post_warmup = step_sizes[.!is_wup] + + @test length(post_warmup) > 0 + # All post-warmup step sizes must be identical (frozen ε̄) + @test all(≈(post_warmup[1]), post_warmup) +end + +@testset "AdaptiveMALASampler stationary distribution" begin + D = 3 + model = DensityModel(logp_adapt, gradlogp_adapt, D) + n_w = 1_000 + sampler = AdaptiveMALASampler(0.5; n_warmup=n_w) + + chain = sample( + MersenneTwister(2025), model, sampler, n_w + 5_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + # Discard warmup + post = Array(chain[(n_w + 1):end, :, :]) # 5000 × D + + mu = vec(mean(post; dims=1)) + vars = vec(var(post; dims=1)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1.0)) < 0.25 +end + +@testset "AdaptiveMALASampler parallel chains via MCMCThreads" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.1; n_warmup=20) + + chains = sample( + MersenneTwister(42), model, sampler, + ParallelMCMC.AbstractMCMC.MCMCThreads(), 60, 2; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chains isa MCMCChains.Chains + @test size(chains, 1) == 60 + @test size(chains, 3) == 2 +end diff --git a/test/test-Batched-MALA.jl b/test/test-Batched-MALA.jl new file mode 100644 index 0000000..5edcacd --- /dev/null +++ b/test/test-Batched-MALA.jl @@ -0,0 +1,111 @@ +using Test +using Random +using LinearAlgebra +using Statistics + +using ParallelMCMC +const MALA = ParallelMCMC.MALA + +# Standard normal target: logp(x) = -0.5 ||x||², grad = -x +logp_batch(X) = vec(-0.5 .* sum(abs2, X; dims=1)) # D×N → N +gradlogp_batch(X) = -X # D×N → D×N + +@testset "BatchedDensityModel construction" begin + bm = BatchedDensityModel(logp_batch, gradlogp_batch, 4) + @test bm.dim == 4 + + X = randn(4, 10) + @test length(bm.logdensity_batch(X)) == 10 + @test size(bm.grad_logdensity_batch(X)) == (4, 10) +end + +@testset "mala_step_batched output shapes" 
begin + rng = MersenneTwister(1) + D, N = 3, 20 + X = randn(rng, D, N) + Ξ = randn(rng, D, N) + u = rand(rng, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) + + @test size(X_next) == (D, N) + @test length(accepted) == N + @test eltype(X_next) == Float64 +end + +@testset "mala_step_batched single accepted chain matches scalar step" begin + rng = MersenneTwister(42) + D = 3 + + x = randn(rng, D) + ξ = randn(rng, D) + u = rand(rng) + + # Scalar step + x_scalar, acc_scalar = MALA.mala_step_full( + x -> -0.5 * dot(x, x), x -> -x, x, 0.1, ξ, u, + ) + + # Batched step with N=1 + X = reshape(copy(x), D, 1) + Xi = reshape(copy(ξ), D, 1) + u_vec = [u] + X_next, accepted_vec = MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1, Xi, u_vec, + ) + + @test isapprox(vec(X_next), x_scalar; atol=1e-12) + @test Bool(accepted_vec[1]) == acc_scalar +end + +@testset "mala_step_batched preserves rejected chains" begin + rng = MersenneTwister(5) + D, N = 4, 50 + + X = randn(rng, D, N) + Ξ = randn(rng, D, N) + # Force all rejections: u very close to 1 so log(u) ≈ 0 > any logα + u = ones(N) .* (1 - 1e-15) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) + + # With u ≈ 1, log(u) ≈ 0 which is > most logα; almost all should be rejected + # (standard normal is well-behaved so this should reject everything or near it) + n_accepted = sum(accepted) + # At this u value essentially all should reject + @test n_accepted < N +end + +@testset "mala_step_batched DimensionMismatch errors" begin + X = randn(3, 5) + Ξ_bad = randn(3, 4) # wrong N + u_good = rand(5) + u_bad = rand(4) # wrong N + + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1, Ξ_bad, u_good + ) + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1, randn(3, 5), u_bad + ) +end + +@testset "mala_step_batched stationary distribution (many chains)" begin + rng = 
MersenneTwister(2025) + D, N, T = 3, 200, 2_000 + + X = randn(rng, D, N) + + for _ in 1:T + Ξ = randn(rng, D, N) + u = rand(rng, N) + X, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) + end + + # Pool all chains after burn-in (already burned in by running T steps) + mu = vec(mean(X; dims=2)) + vars = vec(var(X; dims=2)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1.0)) < 0.20 +end diff --git a/test/test-DEER-Interface.jl b/test/test-DEER-Interface.jl new file mode 100644 index 0000000..eaa0492 --- /dev/null +++ b/test/test-DEER-Interface.jl @@ -0,0 +1,170 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC + +logp_deer(x) = -0.5 * dot(x, x) +gradlogp_deer(x) = -x + +@testset "DEERSampler construction" begin + s = DEERSampler(0.05) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon == 0.05 + @test s.T == 64 + @test s.maxiter == 200 + @test s.jacobian === :diag + + # keyword overrides + s2 = DEERSampler(0.1; T=32, jacobian=:stoch_diag, damping=0.8) + @test s2.T == 32 + @test s2.jacobian === :stoch_diag + @test s2.damping == 0.8 +end + +@testset "DEERSampler initial step" begin + rng = MersenneTwister(42) + model = DensityModel(logp_deer, gradlogp_deer, 3) + sampler = DEERSampler(0.05; T=16) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test trans isa DEERTransition + @test state isa DEERState + @test length(trans.x) == 3 + @test isfinite(trans.logp) + @test trans.logp ≈ logp_deer(trans.x) + @test state.t == 1 + @test size(state.trajectory) == (3, 16) + @test length(state.tape) == 16 +end + +@testset "DEERSampler initial step respects initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_deer, gradlogp_deer, 3) + sampler = DEERSampler(0.05; T=8) + x0 = [1.0, 2.0, 3.0] + x0_copy = copy(x0) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + + # 
initial_params not mutated + @test x0 == x0_copy + # first sample should differ from x0 (DEER solves the trajectory) + @test length(trans.x) == 3 + @test isfinite(trans.logp) +end + +@testset "DEERSampler sequential steps advance trajectory index" begin + rng = MersenneTwister(7) + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=8) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + @test state.t == 1 + + trans2, state2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state2.t == 2 + @test trans2.x ≈ state.trajectory[:, 2] + + trans3, state3 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state2) + @test state3.t == 3 +end + +@testset "DEERSampler re-solves at trajectory boundary" begin + rng = MersenneTwister(99) + model = DensityModel(logp_deer, gradlogp_deer, 2) + T = 4 + sampler = DEERSampler(0.05; T=T) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + # advance to t == T + for _ in 1:(T - 1) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + end + @test state.t == T + + # next step should trigger re-solve: t resets to 1 with a new trajectory + _, state_new = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state_new.t == 1 + # new trajectory should differ from old (different tape) + @test state_new.trajectory !== state.trajectory +end + +@testset "DEERSampler sample() end-to-end" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + samples = sample(MersenneTwister(1), model, sampler, 50; progress=false) + @test length(samples) == 50 +end + +@testset "DEERSampler sample() with chain_type=Chains" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + chain = sample( + MersenneTwister(1), model, sampler, 100; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 100 + @test 
:logp in names(chain, :internals) + @test all(isfinite, chain[:logp]) + + param_names = names(chain, :parameters) + @test length(param_names) == 2 +end + +@testset "DEERSampler sample() with custom param_names" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + chain = sample( + MersenneTwister(2), model, sampler, 40; + chain_type=MCMCChains.Chains, progress=false, + param_names=[:mu, :sigma], + ) + + @test :mu in names(chain, :parameters) + @test :sigma in names(chain, :parameters) +end + +@testset "DEERSampler stationary distribution" begin + D = 3 + model = DensityModel(logp_deer, gradlogp_deer, D) + sampler = DEERSampler(0.1; T=32, damping=0.5) + + chain = sample( + MersenneTwister(2025), model, sampler, 5_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + burn = 500 + post = Array(chain[burn:end, :, :]) # (N-burn) × D + + mu = vec(mean(post; dims=1)) + vars = vec(var(post; dims=1)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1.0)) < 0.25 +end + +@testset "DEERSampler parallel chains via MCMCThreads" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=8) + + chains = sample( + MersenneTwister(42), model, sampler, + ParallelMCMC.AbstractMCMC.MCMCThreads(), 40, 2; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chains isa MCMCChains.Chains + @test size(chains, 1) == 40 # samples per chain + @test size(chains, 3) == 2 # number of chains +end From b85721abac277004c15e0b03316ff04a953597d1 Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Tue, 17 Mar 2026 15:17:13 -0400 Subject: [PATCH 4/8] Begin adding Parallel GPU DEER across sequence. 
Remove across chains version --- Project.toml | 4 +- src/DEER/DEER.jl | 308 +++++++++++++++++--------------------- src/MALA/MALA.jl | 122 ++------------- src/ParallelMCMC.jl | 3 +- src/interface.jl | 283 +++++++++++++++++----------------- test/Project.toml | 3 + test/test-Batched-MALA.jl | 111 -------------- test/test-GPU-DEER.jl | 111 ++++++++++++++ 8 files changed, 415 insertions(+), 530 deletions(-) delete mode 100644 test/test-Batched-MALA.jl create mode 100644 test/test-GPU-DEER.jl diff --git a/Project.toml b/Project.toml index bba21f2..d260ca7 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ authors = ["Ryan Senne "] [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" @@ -21,10 +22,11 @@ LogDensityProblemsExt = "LogDensityProblems" [compat] ADTypes = "1.21.0" -LogDensityProblems = "2" AbstractMCMC = "5.10.0" +CUDA = "5.11.0" DifferentiationInterface = "0.7.13" LinearAlgebra = "1.12.0" +LogDensityProblems = "2" MCMCChains = "7.7.0" Mooncake = "0.4.192" Random = "1.11.0" diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index eddcbb6..1b18f44 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -15,9 +15,10 @@ Deterministic recursion driven by a pre-generated tape. - `step_fwd(x, tape_t)`: exact forward transition (may include MH accept/reject). - `step_lin(x, tape_t, c...)`: surrogate used only for Jacobians. -- `consts(x, tape_t) -> Tuple`: returns constants `c...` passed into `step_lin` as DI.Constant. +- `consts(x, tape_t) -> Tuple`: returns constants `c...` passed into `step_lin`. - `const_example`: example tuple of constants, used in `prepare`. - `backend`: AD backend (any `ADTypes.AbstractADType`); defaults to `AutoMooncake`. 
+ Pass `AutoEnzyme()` or another GPU-compatible backend for GPU execution. """ struct TapedRecursion{Ff,Fl,Fc,Tt,Ce,AD<:AbstractADType} step_fwd::Ff @@ -28,7 +29,7 @@ struct TapedRecursion{Ff,Fl,Fc,Tt,Ce,AD<:AbstractADType} backend::AD end -"Backward-compatible constructor: uses the same step for forward + Jacobian, and no extra constants." +"Backward-compatible constructor: uses the same step for forward + Jacobian, no constants." function TapedRecursion(step, tape::Vector) return TapedRecursion(step, step, (_x, _tt) -> (), tape, (), DEFAULT_BACKEND) end @@ -49,7 +50,7 @@ end "Prepare Jacobian for the surrogate step (reusable across t as long as tape element type is stable)." function prepare(rec::TapedRecursion, x0::AbstractVector) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.const_example return DI.prepare_jacobian( f, rec.backend, x0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... @@ -58,8 +59,8 @@ end "Prepare pushforward (JVP) for the surrogate step_lin." function prepare_pushforward(rec::TapedRecursion, x0::AbstractVector) - f = StepWithTape(rec) - cs = rec.const_example + f = StepWithTape(rec) + cs = rec.const_example tx0 = (zero(x0),) return DI.prepare_pushforward( f, rec.backend, x0, tx0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... @@ -68,21 +69,21 @@ end "Full Jacobian of surrogate step_lin(x, tape_t, consts...) w.r.t. x." function jac_full(rec::TapedRecursion, prep, x::AbstractVector, t::Int) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.consts(x, rec.tape[t]) return DI.jacobian( f, prep, rec.backend, x, DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)... ) end -"Diagonal of the surrogate Jacobian (debug/correctness mode; not scalable if it computes full J)." +"Diagonal of the surrogate Jacobian via full Jacobian computation." 
function jac_diag(rec::TapedRecursion, prep, x::AbstractVector, t::Int) return diag(jac_full(rec, prep, x, t)) end @inline function _rademacher!(z::AbstractVector, rng::AbstractRNG) for i in eachindex(z) - z[i] = rand(rng, Bool) ? 1.0 : -1.0 + z[i] = rand(rng, Bool) ? one(eltype(z)) : -one(eltype(z)) end return z end @@ -90,22 +91,13 @@ end function _jvp_step_lin( rec::TapedRecursion, prep_pf, x::AbstractVector, t::Int, v::AbstractVector ) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.consts(x, rec.tape[t]) - - # pushforward expects tx as a tuple of tangents tx = (v,) - res = DI.pushforward( - f, - prep_pf, - rec.backend, - x, - tx, - DI.Constant(rec.tape[t]), - (DI.Constant(c) for c in cs)..., + f, prep_pf, rec.backend, x, tx, + DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)..., ) - return res isa Tuple ? res[end] : res end @@ -117,7 +109,7 @@ diag(J) ≈ (1/K) * Σ_k z^(k) ⊙ (J z^(k)), z_i ∈ {±1} Keywords: - probes: number of random probe vectors K (typical 1–4) - rng: RNG for probes -- zbuf, jbuf: optional preallocated buffers (length D) to avoid allocations +- zbuf: optional preallocated buffer (length D) to avoid allocations """ function jac_diag_stoch( rec::TapedRecursion, @@ -128,87 +120,89 @@ function jac_diag_stoch( rng::AbstractRNG=Random.default_rng(), zbuf::Union{Nothing,AbstractVector}=nothing, ) - D = length(x) + D = length(x) + FT = float(eltype(x)) probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - z = zbuf === nothing ? Vector{Float64}(undef, D) : zbuf + z = zbuf === nothing ? 
Vector{FT}(undef, D) : zbuf length(z) == D || throw(ArgumentError("zbuf must have length D")) - d = zeros(Float64, D) - - for k in 1:probes + d = zeros(FT, D) + for _ in 1:probes _rademacher!(z, rng) - jv = _jvp_step_lin(rec, prep_pf, x, t, z) # J*z + jv = _jvp_step_lin(rec, prep_pf, x, t, z) for i in 1:D d[i] += z[i] * jv[i] end end - - invK = 1.0 / probes - for i in 1:D - d[i] *= invK - end + d .*= one(FT) / probes return d end """ -Solve s_t = a_t .* s_{t-1} + b_t for t=1..T using an associative inclusive scan. + solve_affine_scan_diag(A, B, s0) + +Solve the diagonal affine recurrence `s_t = A[:,t] .* s_{t-1} + B[:,t]` for t=1..T +via an associative inclusive parallel-prefix (Hillis-Steele) scan. + +**Complexity**: O(log T) sequential sweep levels; each level is a single broadcast +over all T columns — no loops or `@threads`. The implementation is +**array-type-agnostic**: it works identically on CPU `Matrix`, GPU `CuMatrix`, or +any other `AbstractMatrix`. Inputs: -- A :: D×T matrix, A[:,t] = a_t -- B :: D×T matrix, B[:,t] = b_t -- s0 :: length-D vector +- `A` :: D×T — per-step multiplicative coefficients +- `B` :: D×T — per-step additive offsets +- `s0` :: length-D initial state Returns: -- S :: D×T matrix, S[:,t] = s_t +- `S` :: D×T (same type as A) where `S[:,t] = s_t` """ function solve_affine_scan_diag(A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector) D, T = size(A) size(B) == (D, T) || throw(ArgumentError("B must have the same size as A")) - length(s0) == D || throw(ArgumentError("s0 length must match size(A,1)")) + length(s0) == D || throw(ArgumentError("s0 length must match size(A,1)")) - α = Matrix{Float64}(A) # working prefixes - β = Matrix{Float64}(B) + # Work on copies; swap buffers each level to avoid allocating inside the loop. 
+ α = copy(A) + β = copy(B) αnew = similar(α) βnew = similar(β) offset = 1 while offset < T - @threads for t in 1:T - if t > offset - for i in 1:D - ai = α[i, t] - αnew[i, t] = ai * α[i, t - offset] - βnew[i, t] = ai * β[i, t - offset] + β[i, t] - end - else - for i in 1:D - αnew[i, t] = α[i, t] - βnew[i, t] = β[i, t] - end - end - end + # Columns 1:offset are unchanged at this level. + αnew[:, 1:offset] .= α[:, 1:offset] + βnew[:, 1:offset] .= β[:, 1:offset] + + # Columns offset+1:T combine with their left neighbour at distance `offset`. + # αnew[:,t] = α[:,t] * α[:,t-offset] + # βnew[:,t] = α[:,t] * β[:,t-offset] + β[:,t] + αnew[:, offset+1:T] .= α[:, offset+1:T] .* α[:, 1:T-offset] + βnew[:, offset+1:T] .= α[:, offset+1:T] .* β[:, 1:T-offset] .+ β[:, offset+1:T] + α, αnew = αnew, α β, βnew = βnew, β offset <<= 1 end - S = Matrix{Float64}(undef, D, T) - @threads for t in 1:T - for i in 1:D - S[i, t] = α[i, t] * s0[i] + β[i, t] - end - end + # Apply initial condition: S[:,t] = α[:,t] * s0 + β[:,t] + S = similar(A, D, T) + S .= α .* reshape(s0, D, 1) .+ β return S end """ -Compute one DEER update given a trajectory guess S (D×T). -Returns S_new (D×T). +Compute one DEER update given a trajectory guess `S` (D×T). +Returns `S_new` (D×T). + +The per-timestep Jacobian computation uses the AD backend stored in `rec.backend`. +For GPU execution, construct the `TapedRecursion` with a GPU-compatible backend +(e.g. `AutoEnzyme()`). 
Keywords: -- jacobian: :diag or :full -- damping: in (0,1] +- `jacobian`: `:diag`, `:stoch_diag`, or `:full` +- `damping`: in (0,1] """ function deer_update( rec::TapedRecursion, @@ -220,28 +214,24 @@ function deer_update( probes::Int=1, rng::AbstractRNG=Random.default_rng(), ) - s0 = vec(s0_in) + s0 = vec(s0_in) D, T = size(S) - length(s0) == D || throw(ArgumentError("S must be D×T with D=length(s0)")) - T == length(rec.tape) || throw(ArgumentError("size(S,2) must match length(rec.tape)")) + length(s0) == D || throw(ArgumentError("S must be D×T with D=length(s0)")) + T == length(rec.tape) || throw(ArgumentError("size(S,2) must match length(rec.tape)")) damping > 0 && damping ≤ 1 || throw(ArgumentError("damping must be in (0,1]")) - probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) + probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - if jacobian === :diag || jacobian === :stoch_diag - A = Matrix{Float64}(undef, D, T) - B = Matrix{Float64}(undef, D, T) - - nt = Base.Threads.maxthreadid() - - # scratch vector per thread for xbar - xbufs = [zeros(Float64, D) for _ in 1:nt] - - # scratch per thread for stochastic diag probes - zbufs = jacobian === :stoch_diag ? [zeros(Float64, D) for _ in 1:nt] : nothing + FT = float(eltype(s0)) - # per-thread RNGs for reproducibility + thread safety - rngs = if jacobian === :stoch_diag - # Seed thread RNGs from the provided rng on the main thread (deterministic given rng). + if jacobian === :diag || jacobian === :stoch_diag + # A and B match the array type of S (CuMatrix on GPU, Matrix on CPU). + A = similar(S) + B = similar(S) + + nt = Base.Threads.maxthreadid() + # zbufs for stoch_diag: match array type of s0 so GPU backends work correctly. + zbufs = jacobian === :stoch_diag ? 
[similar(s0, D) for _ in 1:nt] : nothing + rngs = if jacobian === :stoch_diag seeds = rand(rng, UInt, nt) [MersenneTwister(seeds[i]) for i in 1:nt] else @@ -249,79 +239,59 @@ function deer_update( end @threads for t in 1:T - tid = threadid() - - xbar = if t == 1 - s0 - else - xb = xbufs[tid] - for i in 1:D - xb[i] = S[i, t - 1] - end - xb - end + tid = threadid() + # S[:, t-1] returns a concrete column copy — Vector on CPU, CuVector on GPU. + xbar = t == 1 ? s0 : S[:, t - 1] jt = if jacobian === :diag jac_diag(rec, prep, xbar, t) else jac_diag_stoch( - rec, prep, xbar, t; probes=probes, rng=rngs[tid], zbuf=zbufs[tid] + rec, prep, xbar, t; probes=probes, rng=rngs[tid], zbuf=zbufs[tid], ) end ft = rec.step_fwd(xbar, rec.tape[t]) - - for i in 1:D - A[i, t] = jt[i] - B[i, t] = ft[i] - jt[i] * xbar[i] - end + # Broadcast column assignment: works on both CPU Matrix and GPU CuMatrix. + view(A, :, t) .= jt + view(B, :, t) .= ft .- jt .* xbar end S_new = solve_affine_scan_diag(A, B, s0) - if damping != 1.0 @. 
S_new = (1 - damping) * S + damping * S_new end - return S_new elseif jacobian === :full - A = Vector{Matrix{Float64}}(undef, T) - b = Matrix{Float64}(undef, D, T) - - xbuf = zeros(Float64, D) # sequential scratch + A_mats = Vector{Matrix{FT}}(undef, T) + b = Matrix{FT}(undef, D, T) + xbuf = zeros(FT, D) for t in 1:T xbar = if t == 1 s0 else - for i in 1:D - xbuf[i] = S[i, t - 1] - end + for i in 1:D; xbuf[i] = S[i, t - 1]; end xbuf end - Jt = jac_full(rec, prep, xbar, t) - A[t] = Jt - - ft = rec.step_fwd(xbar, rec.tape[t]) - tmp = Jt * xbar - for i in 1:D - b[i, t] = ft[i] - tmp[i] - end + Jt = jac_full(rec, prep, xbar, t) + A_mats[t] = Jt + ft = rec.step_fwd(xbar, rec.tape[t]) + tmp = Jt * xbar + for i in 1:D; b[i, t] = ft[i] - tmp[i]; end end - S_new = Matrix{Float64}(undef, D, T) + S_new = Matrix{FT}(undef, D, T) s_prev = copy(s0) for t in 1:T - s_prev = A[t] * s_prev .+ view(b, :, t) + s_prev = A_mats[t] * s_prev .+ view(b, :, t) S_new[:, t] .= s_prev end - if damping != 1.0 S_new .= (1 - damping) .* S .+ damping .* S_new end - return S_new else throw(ArgumentError("jacobian must be :diag, :stoch_diag, or :full")) @@ -329,7 +299,7 @@ function deer_update( end @inline function _maxabs(x) - m = 0.0 + m = zero(real(eltype(x))) for i in eachindex(x) v = abs(x[i]) m = ifelse(v > m, v, m) @@ -338,7 +308,7 @@ end end @inline function _maxabsdiff(x, y) - m = 0.0 + m = zero(real(promote_type(eltype(x), eltype(y)))) for i in eachindex(x, y) v = abs(x[i] - y[i]) m = ifelse(v > m, v, m) @@ -349,16 +319,25 @@ end """ Run DEER iterations until convergence. -Returns: -- S :: D×T matrix trajectory (S[:,t] is state at time t) +Returns the solved trajectory `S :: D×T` (same array type as the initial state, +so passing a `CuVector` for `s0_in` will yield a `CuMatrix`). + +# GPU use +1. Construct a `DensityModel` whose `logdensity` and `grad_logdensity` operate on + GPU arrays. +2. 
Pass `backend = AutoEnzyme()` (or another GPU-compatible `ADTypes` backend) to + `DEERSampler` so that DEER uses it when building the `TapedRecursion`. +3. Pass a GPU vector as `initial_params` (or use a `DEERSampler` with a float type + matching the GPU array element type). + +The `solve_affine_scan_diag` kernel runs as pure broadcasts and requires no +backend-specific GPU code; only the Jacobian computation via `DI` depends on the +chosen backend. Keywords: -- init: initial trajectory guess, D×T (default: repeat s0) -- tol_abs, tol_rel: stopping tolerances (∞-norm per time step) -- maxiter -- jacobian: :diag or :full -- damping -- return_info +- `init`: initial trajectory guess D×T (default: repeat s0 across columns) +- `tol_abs`, `tol_rel`: stopping tolerances (∞-norm per time step) +- `maxiter`, `jacobian`, `damping`, `probes`, `return_info` """ function solve( rec::TapedRecursion, @@ -374,73 +353,62 @@ function solve( return_info::Bool=false, ) s0 = vec(s0_in) - D = length(s0) - T = length(rec.tape) + D = length(s0) + T = length(rec.tape) maxiter ≥ 1 || throw(ArgumentError("maxiter must be ≥ 1")) tol_abs ≥ 0 || throw(ArgumentError("tol_abs must be ≥ 0")) tol_rel ≥ 0 || throw(ArgumentError("tol_rel must be ≥ 0")) damping > 0 && damping ≤ 1 || throw(ArgumentError("damping must be in (0,1]")) + # Initial trajectory guess: preserve array type of s0 S = if init === nothing - # repeat s0 across time - M = Matrix{Float64}(undef, D, T) - @threads for t in 1:T - M[:, t] .= s0 - end - M + S0 = similar(s0, D, T) + S0 .= reshape(s0, D, 1) + S0 else size(init) == (D, T) || throw(ArgumentError("init must be size (D,T)")) - Matrix{Float64}(init) + copy(init) end prep = if jacobian === :stoch_diag prepare_pushforward(rec, s0) - else + elseif jacobian === :full || jacobian === :diag prepare(rec, s0) + else + nothing end - converged = false + converged = false last_metric = Inf - iters = 0 + iters = 0 for iter in 1:maxiter iters = iter S_new = deer_update( - rec, s0, S, prep; 
jacobian=jacobian, damping=damping, probes=probes, rng=rng + rec, s0, S, prep; jacobian=jacobian, damping=damping, probes=probes, rng=rng, ) - metric = 0.0 - for t in 1:T - xnew = view(S_new, :, t) - xold = view(S, :, t) - Δ = _maxabsdiff(xnew, xold) - scale = tol_abs + tol_rel * _maxabs(xnew) - metric = max(metric, Δ / scale) - end + # Single pair of reductions over the full D×T matrix — one GPU kernel each. + # This replaces T per-column scalar loops and is efficient on both CPU and GPU. + Δ_max = maximum(abs.(S_new .- S)) + S_scale = tol_abs + tol_rel * maximum(abs.(S_new)) + metric = Δ_max / S_scale - S = S_new + S = S_new last_metric = metric - - if metric ≤ 1.0 - converged = true - break - end + metric ≤ 1 && (converged = true; break) end - if return_info - return S, - ( - converged=converged, - iters=iters, - metric=last_metric, - jacobian=jacobian, - damping=damping, - probes=probes, - ) - else - return S - end + return_info || return S + return S, ( + converged = converged, + iters = iters, + metric = last_metric, + jacobian = jacobian, + damping = damping, + probes = probes, + ) end end # module DEER diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 8f1e936..a4c6266 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -129,15 +129,17 @@ end """ Primal accept indicator for a taped MALA step. -Returns Float64 in {0.0, 1.0} so it can be used as a constant gate. +Returns a float in {0, 1} matching the precision of `u`, for use as a constant +gate in the DEER surrogate step. """ function mala_accept_indicator( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing, ) - y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) + y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) logα = mala_logα(logp, gradlogp, x, y, ϵ; cholM=cholM) - return (log(u) < logα) ? 1.0 : 0.0 + FP = typeof(float(u)) + return (log(u) < logα) ? 
one(FP) : zero(FP) end """ @@ -164,7 +166,7 @@ log acceptance ratio `logα = log p(y) + log q(x|y) - log p(x) - log q(y|x)`. The returned `logα` is the un-clamped value; the actual acceptance probability is `min(1, exp(logα))`. This is needed by adaptive step-size schemes (dual averaging). -Returns `(x_next, accepted::Bool, logα::Float64)`. +Returns `(x_next, accepted::Bool, logα)`. """ function mala_step_with_logα( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; @@ -174,24 +176,24 @@ function mala_step_with_logα( 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) g_x = gradlogp(x) - y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) + y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) logp_x = logp(x) logp_y = logp(y) - g_y = gradlogp(y) + g_y = gradlogp(y) logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) - logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) + logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) accepted = log(u) < logα - x_next = accepted ? y : x - return x_next, accepted, Float64(logα) + x_next = accepted ? y : x + return x_next, accepted, logα end """ Stop-gradient surrogate step used for Jacobians. -`a` (0.0 or 1.0) must be provided as a constant by the DEER machinery. +`a` (0 or 1) must be provided as a constant by the DEER machinery. """ function mala_step_surrogate( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, a::Real; @@ -201,104 +203,4 @@ function mala_step_surrogate( return (a .* y) .+ ((1 - a) .* x) end -# Apply mass matrix to a D×N matrix of gradient columns (same math as scalar, -# matrix multiply broadcasts naturally). 
-_apply_M_batched(G::AbstractMatrix, ::Nothing) = G -_apply_M_batched(G::AbstractMatrix, cholM::Cholesky) = cholM.L * (cholM.L' * G) - -_apply_L_batched(Ξ::AbstractMatrix, ::Nothing) = Ξ -_apply_L_batched(Ξ::AbstractMatrix, cholM::Cholesky) = cholM.L * Ξ - -""" -Compute column-wise M⁻¹-norm squared: `[||R[:,n]||²_{M⁻¹}]_n`. -`R` is D×N; returns a length-N vector. -GPU-compatible (uses `sum(abs2, …; dims=1)` which works on CuArrays). -Note: `cholM` must be `nothing` when using GPU arrays; the triangular solve -`cholM.L \\ R` pulls device arrays to CPU. -""" -function _quad_Minv_batched(R::AbstractMatrix, ::Nothing) - return vec(sum(abs2, R; dims=1)) -end - -function _quad_Minv_batched(R::AbstractMatrix, cholM::Cholesky) - W = cholM.L \ R - return vec(sum(abs2, W; dims=1)) -end - -""" - logq_mala_batched(Y, X, gradlogp_X, ε; cholM=nothing) - -Compute `log q(Y[:,n] | X[:,n])` for all N chains simultaneously. -`Y`, `X`, `gradlogp_X` are D×N; returns a length-N vector. -""" -function logq_mala_batched( - Y::AbstractMatrix, - X::AbstractMatrix, - gradlogp_X::AbstractMatrix, - ε::Real; - cholM=nothing, -) - D = size(X, 1) - μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM) - R = Y .- μ - q = _quad_Minv_batched(R, cholM) - ldet = _logdet_M(cholM) - return @. -0.5 * q / (2ε) - (D / 2) * log(4π * ε) - 0.5 * ldet -end - -""" - mala_step_batched(logp_batch, gradlogp_batch, X, ε, Ξ, u; cholM=nothing) - -Run one MALA step for N chains simultaneously. - -- `X` :: D×N — current states (one chain per column). -- `Ξ` :: D×N — N(0,I) noise. -- `u` :: length-N — Uniform(0,1) draws. -- `logp_batch(X)` → length-N log-densities. -- `gradlogp_batch(X)` → D×N gradient matrix. - -Returns `(X_next::AbstractMatrix, accepted::AbstractVector)`. - -**GPU use:** pass `CuArray` inputs and GPU-compatible `logp_batch`/`gradlogp_batch`. -Requires `cholM=nothing` for full on-device execution (Cholesky preconditioner involves -a CPU-side triangular solve). 
Use `eltype(X)` for `ε` to avoid float-type promotions -that would pull data off GPU. -""" -function mala_step_batched( - logp_batch, - gradlogp_batch, - X::AbstractMatrix, - ε::Real, - Ξ::AbstractMatrix, - u::AbstractVector; - cholM=nothing, -) - D, N = size(X) - size(Ξ) == (D, N) || throw(DimensionMismatch("X and Ξ must have the same size")) - length(u) == N || throw(DimensionMismatch("u must have length N = size(X,2)")) - - # Cast ε to element type of X to avoid float-promotion off GPU. - ε_T = eltype(X)(ε) - - G_X = gradlogp_batch(X) # D×N - Y = X .+ ε_T .* _apply_M_batched(G_X, cholM) .+ - sqrt(2 * ε_T) .* _apply_L_batched(Ξ, cholM) # D×N - - lp_X = logp_batch(X) # N - lp_Y = logp_batch(Y) # N - G_Y = gradlogp_batch(Y) # D×N - - lq_YX = logq_mala_batched(Y, X, G_X, ε_T; cholM=cholM) # N - lq_XY = logq_mala_batched(X, Y, G_Y, ε_T; cholM=cholM) # N - - logα = @. (lp_Y + lq_XY) - (lp_X + lq_YX) # N - accepted = @. log(u) < logα # N Bool - - # Select: proposal if accepted, current if rejected. - # reshape to 1×N so it broadcasts against D×N. - mask = reshape(accepted, 1, N) - X_next = @. 
ifelse(mask, Y, X) # D×N - return X_next, vec(accepted) -end - end # module diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index da2799c..591844d 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -1,6 +1,7 @@ module ParallelMCMC using AbstractMCMC +using CUDA using MCMCChains using LinearAlgebra using Random @@ -10,7 +11,7 @@ include("MALA/MALA.jl") include("DEER/DEER.jl") include("interface.jl") -export DensityModel, BatchedDensityModel +export DensityModel export MALASampler, MALATransition, MALAState export AdaptiveMALASampler, AdaptiveMALATransition, AdaptiveMALAState export DEERSampler, DEERTransition, DEERState, MALATapeElement diff --git a/src/interface.jl b/src/interface.jl index 03ba135..c6b7569 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -21,36 +21,47 @@ struct DensityModel{F,G} <: AbstractMCMC.AbstractModel end """ - MALASampler(epsilon) + MALASampler(epsilon; cholM=nothing) Metropolis-Adjusted Langevin Algorithm sampler with step size `epsilon`. + +Optionally pass `cholM = cholesky(M)` to use a mass matrix `M` as a +preconditioner. The proposal becomes `y = x + ε M ∇logp(x) + √(2ε) L ξ` +where `L` is the Cholesky factor of `M`. 
""" -struct MALASampler <: AbstractMCMC.AbstractSampler - epsilon::Float64 +struct MALASampler{FP<:AbstractFloat, CM} <: AbstractMCMC.AbstractSampler + epsilon::FP + cholM::CM +end + +function MALASampler(epsilon::Real; cholM=nothing) + epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) + eps_f = float(epsilon) + return MALASampler{typeof(eps_f), typeof(cholM)}(eps_f, cholM) end -struct MALAState{V<:AbstractVector} +struct MALAState{V<:AbstractVector, L<:Real} x::V - logp::Float64 + logp::L end -struct MALATransition{V<:AbstractVector} +struct MALATransition{V<:AbstractVector, L<:Real} x::V - logp::Float64 + logp::L accepted::Bool end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, - sampler::MALASampler; + sampler::MALASampler{FP}; initial_params=nothing, kwargs..., -) +) where {FP} x = if initial_params !== nothing copy(initial_params) else - randn(rng, model.dim) + randn(rng, FP, model.dim) end logp_val = model.logdensity(x) t = MALATransition(x, logp_val, true) @@ -69,11 +80,12 @@ function AbstractMCMC.step( ϵ = sampler.epsilon D = model.dim - ξ = randn(rng, D) + ξ = randn(rng, eltype(x), D) u = rand(rng) x_next, accepted = MALA.mala_step_full( - model.logdensity, model.grad_logdensity, x, ϵ, ξ, u + model.logdensity, model.grad_logdensity, x, ϵ, ξ, u; + cholM=sampler.cholM, ) logp_val = accepted ? model.logdensity(x_next) : state.logp @@ -102,7 +114,7 @@ function AbstractMCMC.bundle_samples( internal_names = [:logp, :accepted] - vals = Matrix{Float64}(undef, N, D) + vals = Matrix{Float64}(undef, N, D) internals = Matrix{Float64}(undef, N, 2) for i in 1:N @@ -126,17 +138,17 @@ One element of the MALA noise tape: a noise vector `ξ ~ N(0,I)` and a uniform scalar `u ~ Uniform(0,1)`. Stored with a concrete vector type `V` for type stability inside `DEER.TapedRecursion`. 
""" -struct MALATapeElement{V<:AbstractVector} +struct MALATapeElement{FP<:AbstractFloat, V<:AbstractVector{FP}} ξ::V - u::Float64 + u::FP end """ - DEERSampler(epsilon; T, maxiter, tol_abs, tol_rel, jacobian, damping, probes) + DEERSampler(epsilon; T, maxiter, tol_abs, tol_rel, jacobian, damping, probes, cholM, backend) DEER-accelerated MALA sampler. -DEER solves for a trajectory of `T` steps in parallel (O(log T) iterations), +DEER solves for a trajectory of `T` steps in parallel (O(log T) sweep levels), then the AbstractMCMC interface returns samples from that trajectory sequentially. When the trajectory is exhausted a new tape is drawn and DEER re-solves starting from the last state. @@ -149,24 +161,33 @@ re-solves starting from the last state. - `jacobian` — Jacobian mode: `:diag`, `:stoch_diag`, or `:full` (default `:diag`). - `damping` — DEER damping in (0,1] (default 0.5; helps convergence). - `probes` — Hutchinson probes for `:stoch_diag` mode (default 1). +- `cholM` — optional Cholesky factor of a mass matrix `M` (default `nothing` = identity). +- `backend` — AD backend for Jacobian computation (default `AutoMooncake()`). + For GPU execution pass a GPU-compatible backend, e.g. `AutoEnzyme()`. + +# GPU use +The parallel-prefix scan (`solve_affine_scan_diag`) runs as pure array broadcasts +and is array-type-agnostic. To run DEER on GPU: +1. Implement `logdensity` and `grad_logdensity` using GPU-compatible operations. +2. Pass `backend = AutoEnzyme()` (or another GPU-compatible `ADTypes` backend). +3. Pass a GPU vector as `initial_params` to `sample`. # Parallel chains Both `MALASampler` and `DEERSampler` are compatible with `AbstractMCMC.sample(model, sampler, MCMCThreads(), N, nchains)`. Each chain -has its own immutable state and RNG so there is no shared mutable data. 
-Note that each `DEERSampler` chain internally uses `Base.Threads.@threads` -for the parallel scan, so running many chains via `MCMCThreads()` on a machine -with few threads may over-subscribe the thread pool. +has its own immutable state so there is no shared mutable data. """ -struct DEERSampler <: AbstractMCMC.AbstractSampler - epsilon::Float64 +struct DEERSampler{FP<:AbstractFloat, CM, AD} <: AbstractMCMC.AbstractSampler + epsilon::FP T::Int maxiter::Int - tol_abs::Float64 - tol_rel::Float64 + tol_abs::FP + tol_rel::FP jacobian::Symbol - damping::Float64 + damping::FP probes::Int + cholM::CM + backend::AD end function DEERSampler( @@ -178,11 +199,16 @@ function DEERSampler( jacobian::Symbol=:diag, damping::Real=0.5, probes::Int=1, + cholM=nothing, + backend=DEER.DEFAULT_BACKEND, ) - return DEERSampler( - Float64(epsilon), T, maxiter, - Float64(tol_abs), Float64(tol_rel), - jacobian, Float64(damping), probes, + epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) + eps_f = float(epsilon) + FP = typeof(eps_f) + return DEERSampler{FP, typeof(cholM), typeof(backend)}( + eps_f, T, maxiter, + FP(tol_abs), FP(tol_rel), + jacobian, FP(damping), probes, cholM, backend, ) end @@ -195,35 +221,36 @@ State for a `DEERSampler` chain. - `tape` — noise tape used for that solve. - `t` — index within `trajectory` of the last returned sample (1-indexed). """ -struct DEERState{V<:AbstractVector, M<:AbstractMatrix} +struct DEERState{V<:AbstractVector, L<:Real, M<:AbstractMatrix} x::V - logp::Float64 + logp::L trajectory::M - tape::Vector{MALATapeElement{V}} + tape::Vector{<:MALATapeElement} t::Int end """ One DEER sample: parameter vector `x` and its log-density `logp`. 
""" -struct DEERTransition{V<:AbstractVector} +struct DEERTransition{V<:AbstractVector, L<:Real} x::V - logp::Float64 + logp::L end function _build_mala_deer_rec( - model::DensityModel, ε::Float64, tape::Vector{<:MALATapeElement} + model::DensityModel, ε::Real, tape::Vector{<:MALATapeElement}; + cholM=nothing, backend=DEER.DEFAULT_BACKEND, ) logp = model.logdensity gradlogp = model.grad_logdensity - step_fwd = (x, te) -> MALA.mala_step_taped(logp, gradlogp, x, ε, te.ξ, te.u) - step_lin = (x, te, a) -> MALA.mala_step_surrogate(logp, gradlogp, x, ε, te.ξ, a) - consts = (x, te) -> (MALA.mala_accept_indicator(logp, gradlogp, x, ε, te.ξ, te.u),) + step_fwd = (x, te) -> MALA.mala_step_taped(logp, gradlogp, x, ε, te.ξ, te.u; cholM=cholM) + step_lin = (x, te, a) -> MALA.mala_step_surrogate(logp, gradlogp, x, ε, te.ξ, a; cholM=cholM) + consts = (x, te) -> (MALA.mala_accept_indicator(logp, gradlogp, x, ε, te.ξ, te.u; cholM=cholM),) return DEER.TapedRecursion( step_fwd, step_lin, tape; - consts=consts, const_example=(0.0,), + consts=consts, const_example=(0.0,), backend=backend, ) end @@ -235,8 +262,14 @@ function _deer_solve_new_tape( ) D = model.dim T = sampler.T - tape = [MALATapeElement(randn(rng, D), rand(rng)) for _ in 1:T] - rec = _build_mala_deer_rec(model, sampler.epsilon, tape) + FP = typeof(sampler.epsilon) + # Generate noise on the same device as x0: generate on CPU then copy via copyto!. + # copyto!(CuVector, Vector) performs a host-to-device transfer in CUDA.jl. 
+ tape = map(1:T) do _ + ξ = copyto!(similar(x0, D), randn(rng, FP, D)) + MALATapeElement(ξ, FP(rand(rng))) + end + rec = _build_mala_deer_rec(model, sampler.epsilon, tape; cholM=sampler.cholM, backend=sampler.backend) S = DEER.solve( rec, x0; tol_abs = sampler.tol_abs, @@ -253,22 +286,21 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, - sampler::DEERSampler; + sampler::DEERSampler{FP}; initial_params=nothing, kwargs..., -) +) where {FP} x0 = if initial_params !== nothing copy(initial_params) else - randn(rng, model.dim) + randn(rng, FP, model.dim) end - S, tape = _deer_solve_new_tape(rng, model, sampler, x0) - - x1 = S[:, 1] - logp1 = Float64(model.logdensity(x1)) - trans = DEERTransition(x1, logp1) - state = DEERState(x1, logp1, S, tape, 1) + S, tape = _deer_solve_new_tape(rng, model, sampler, x0) + x1 = S[:, 1] + logp1 = model.logdensity(x1) + trans = DEERTransition(x1, logp1) + state = DEERState(x1, logp1, S, tape, 1) return trans, state end @@ -285,8 +317,8 @@ function AbstractMCMC.step( if t_next <= T # Consume the next cached sample from the trajectory. 
x_new = state.trajectory[:, t_next] - logp_new = Float64(model.logdensity(x_new)) - trans = DEERTransition(x_new, logp_new) + logp_new = model.logdensity(x_new) + trans = DEERTransition(x_new, logp_new) new_state = DEERState(x_new, logp_new, state.trajectory, state.tape, t_next) return trans, new_state else @@ -294,8 +326,8 @@ function AbstractMCMC.step( x0 = state.trajectory[:, T] S_new, tape = _deer_solve_new_tape(rng, model, sampler, x0) x_new = S_new[:, 1] - logp_new = Float64(model.logdensity(x_new)) - trans = DEERTransition(x_new, logp_new) + logp_new = model.logdensity(x_new) + trans = DEERTransition(x_new, logp_new) new_state = DEERState(x_new, logp_new, S_new, tape, 1) return trans, new_state end @@ -325,7 +357,7 @@ function AbstractMCMC.bundle_samples( internals = Matrix{Float64}(undef, N, 1) for i in 1:N - vals[i, :] .= samples[i].x + vals[i, :] .= samples[i].x internals[i, 1] = samples[i].logp end @@ -336,42 +368,6 @@ function AbstractMCMC.bundle_samples( ) end -""" - BatchedDensityModel(logdensity_batch, grad_logdensity_batch, dim) - -Wraps batched log-density and gradient functions for use with -`MALA.mala_step_batched`. - -- `logdensity_batch(X::AbstractMatrix) -> AbstractVector` — - `X` is D×N; returns a length-N vector of log-densities. -- `grad_logdensity_batch(X::AbstractMatrix) -> AbstractMatrix` — - `X` is D×N; returns a D×N gradient matrix (column `n` = gradient for chain `n`). -- `dim::Int` — dimensionality D of the parameter space. - -# GPU use -Pass `CuMatrix` inputs and implement the two functions to return `CuArray`s. -Requires `cholM=nothing` in `mala_step_batched` for fully on-device execution. 
- -# Example -```julia -logp_b(X) = vec(sum(x -> -0.5x^2, X; dims=1)) # standard normal, N chains -gradlogp_b(X) = -X -bmodel = BatchedDensityModel(logp_b, gradlogp_b, 3) - -X = randn(3, 100) # 100 chains of dimension 3 -Xi = randn(3, 100) -u = rand(100) -X_next, accepted = MALA.mala_step_batched(bmodel.logdensity_batch, - bmodel.grad_logdensity_batch, - X, 0.1, Xi, u) -``` -""" -struct BatchedDensityModel{F,G} <: AbstractMCMC.AbstractModel - logdensity_batch::F - grad_logdensity_batch::G - dim::Int -end - """ AdaptiveMALASampler(epsilon_init; n_warmup, target_accept, gamma, t0, kappa, cholM) @@ -414,13 +410,13 @@ Load `LogDensityProblems` (and optionally `LogDensityProblemsAD`) then use the `DensityModel(ld)` constructor to wrap any `LogDensityProblems`-compatible model (including Turing/DynamicPPL models) directly. """ -struct AdaptiveMALASampler{CM} <: AbstractMCMC.AbstractSampler - epsilon_init::Float64 +struct AdaptiveMALASampler{FP<:AbstractFloat, CM} <: AbstractMCMC.AbstractSampler + epsilon_init::FP n_warmup::Int - target_accept::Float64 - gamma::Float64 - t0::Float64 - kappa::Float64 + target_accept::FP + gamma::FP + t0::FP + kappa::FP cholM::CM end @@ -433,48 +429,58 @@ function AdaptiveMALASampler( kappa::Real=0.75, cholM=nothing, ) - return AdaptiveMALASampler( - Float64(epsilon_init), n_warmup, - Float64(target_accept), Float64(gamma), Float64(t0), Float64(kappa), + epsilon_init > 0 || throw(ArgumentError("epsilon_init must be > 0, got $epsilon_init")) + 0 < target_accept < 1 || throw(ArgumentError("target_accept must be in (0,1), got $target_accept")) + gamma > 0 || throw(ArgumentError("gamma must be > 0, got $gamma")) + t0 > 0 || throw(ArgumentError("t0 must be > 0, got $t0")) + 0.5 < kappa <= 1.0 || throw(ArgumentError("kappa must be in (0.5, 1], got $kappa")) + + eps_f = float(epsilon_init) + FP = typeof(eps_f) + return AdaptiveMALASampler{FP, typeof(cholM)}( + eps_f, n_warmup, + FP(target_accept), FP(gamma), FP(t0), FP(kappa), cholM, ) end -struct 
AdaptiveMALAState{V<:AbstractVector} +struct AdaptiveMALAState{V<:AbstractVector, FP<:AbstractFloat} x::V - logp::Float64 - epsilon::Float64 # instantaneous step size ε_m - epsilon_bar::Float64 # smoothed step size ε̄_m (frozen after warmup) - H_bar::Float64 # dual-average statistic H̄_m - step::Int # warmup step counter (0 = initialisation) + logp::FP + epsilon::FP # instantaneous step size ε_m + epsilon_bar::FP # smoothed step size ε̄_m (frozen after warmup) + H_bar::FP # dual-average statistic H̄_m + step::Int # warmup step counter (0 = initialisation) end -struct AdaptiveMALATransition{V<:AbstractVector} +struct AdaptiveMALATransition{V<:AbstractVector, FP<:AbstractFloat} x::V - logp::Float64 + logp::FP accepted::Bool - step_size::Float64 # ε used for this step + step_size::FP # ε used for this step is_warmup::Bool end function _dual_average_update( - epsilon_init::Float64, - epsilon_bar::Float64, - H_bar::Float64, + epsilon_init::FP, + epsilon_bar::FP, + H_bar::FP, m::Int, - logα::Float64, - sampler::AdaptiveMALASampler, -) - α = min(1.0, exp(logα)) + logα::FP, + sampler::AdaptiveMALASampler{FP}, +) where {FP<:AbstractFloat} + α = min(one(FP), exp(logα)) δ = sampler.target_accept γ = sampler.gamma t0 = sampler.t0 κ = sampler.kappa - μ = log(10.0 * epsilon_init) # fixed shrinkage target + μ = log(10 * epsilon_init) # fixed shrinkage target - H_bar_new = (1.0 - 1.0 / (m + t0)) * H_bar + (1.0 / (m + t0)) * (δ - α) - log_ε = μ - sqrt(Float64(m)) / γ * H_bar_new - log_ε_bar_new = Float64(m)^(-κ) * log_ε + (1.0 - Float64(m)^(-κ)) * log(epsilon_bar) + inv_mt0 = one(FP) / (FP(m) + t0) + H_bar_new = (one(FP) - inv_mt0) * H_bar + inv_mt0 * (δ - α) + log_ε = μ - sqrt(FP(m)) / γ * H_bar_new + mk = FP(m)^(-κ) + log_ε_bar_new = mk * log_ε + (one(FP) - mk) * log(epsilon_bar) return exp(log_ε), exp(log_ε_bar_new), H_bar_new end @@ -482,33 +488,33 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, - sampler::AdaptiveMALASampler; + 
sampler::AdaptiveMALASampler{FP}; initial_params=nothing, kwargs..., -) +) where {FP} x = if initial_params !== nothing copy(initial_params) else - randn(rng, model.dim) + randn(rng, FP, model.dim) end - logp_val = Float64(model.logdensity(x)) + logp_val = FP(model.logdensity(x)) trans = AdaptiveMALATransition(x, logp_val, true, sampler.epsilon_init, true) - state = AdaptiveMALAState(x, logp_val, sampler.epsilon_init, sampler.epsilon_init, 0.0, 0) + state = AdaptiveMALAState(x, logp_val, sampler.epsilon_init, sampler.epsilon_init, zero(FP), 0) return trans, state end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, - sampler::AdaptiveMALASampler, + sampler::AdaptiveMALASampler{FP}, state::AdaptiveMALAState; kwargs..., -) - D = model.dim - in_warmup = state.step < sampler.n_warmup - ε = in_warmup ? state.epsilon : state.epsilon_bar +) where {FP} + D = model.dim + in_warmup = state.step < sampler.n_warmup + ε = in_warmup ? state.epsilon : state.epsilon_bar - ξ = randn(rng, D) + ξ = randn(rng, eltype(state.x), D) u = rand(rng) x_next, accepted, logα = MALA.mala_step_with_logα( @@ -516,12 +522,15 @@ function AbstractMCMC.step( cholM=sampler.cholM, ) - logp_next = accepted ? Float64(model.logdensity(x_next)) : state.logp + logp_next = accepted ? FP(model.logdensity(x_next)) : state.logp # Dual-average adaptation (only during warmup) - m_new = state.step + 1 + m_new = state.step + 1 ε_new, ε_bar_new, H_bar_new = if in_warmup - _dual_average_update(sampler.epsilon_init, state.epsilon_bar, state.H_bar, m_new, logα, sampler) + _dual_average_update( + sampler.epsilon_init, state.epsilon_bar, state.H_bar, + m_new, FP(logα), sampler, + ) else state.epsilon, state.epsilon_bar, state.H_bar end @@ -558,9 +567,9 @@ function AbstractMCMC.bundle_samples( s = samples[i] vals[i, :] .= s.x internals[i, 1] = s.logp - internals[i, 2] = s.accepted ? 1.0 : 0.0 + internals[i, 2] = s.accepted ? 1.0 : 0.0 internals[i, 3] = s.step_size - internals[i, 4] = s.is_warmup ? 
1.0 : 0.0 + internals[i, 4] = s.is_warmup ? 1.0 : 0.0 end return MCMCChains.Chains( diff --git a/test/Project.toml b/test/Project.toml index 693918c..8b27d91 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,8 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/test/test-Batched-MALA.jl b/test/test-Batched-MALA.jl deleted file mode 100644 index 5edcacd..0000000 --- a/test/test-Batched-MALA.jl +++ /dev/null @@ -1,111 +0,0 @@ -using Test -using Random -using LinearAlgebra -using Statistics - -using ParallelMCMC -const MALA = ParallelMCMC.MALA - -# Standard normal target: logp(x) = -0.5 ||x||², grad = -x -logp_batch(X) = vec(-0.5 .* sum(abs2, X; dims=1)) # D×N → N -gradlogp_batch(X) = -X # D×N → D×N - -@testset "BatchedDensityModel construction" begin - bm = BatchedDensityModel(logp_batch, gradlogp_batch, 4) - @test bm.dim == 4 - - X = randn(4, 10) - @test length(bm.logdensity_batch(X)) == 10 - @test size(bm.grad_logdensity_batch(X)) == (4, 10) -end - -@testset "mala_step_batched output shapes" begin - rng = MersenneTwister(1) - D, N = 3, 20 - X = randn(rng, D, N) - Ξ = randn(rng, D, N) - u = rand(rng, N) - - X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) - - @test size(X_next) == (D, N) - @test length(accepted) == N - @test eltype(X_next) == Float64 -end - -@testset "mala_step_batched single accepted chain matches scalar step" begin - rng = MersenneTwister(42) - D = 3 - - x = randn(rng, D) - ξ = randn(rng, D) - u = rand(rng) - - # Scalar step - x_scalar, acc_scalar = MALA.mala_step_full( - x -> -0.5 * dot(x, x), x -> -x, x, 0.1, ξ, u, - ) - - # Batched step with N=1 - X = reshape(copy(x), D, 1) - Xi = 
reshape(copy(ξ), D, 1) - u_vec = [u] - X_next, accepted_vec = MALA.mala_step_batched( - logp_batch, gradlogp_batch, X, 0.1, Xi, u_vec, - ) - - @test isapprox(vec(X_next), x_scalar; atol=1e-12) - @test Bool(accepted_vec[1]) == acc_scalar -end - -@testset "mala_step_batched preserves rejected chains" begin - rng = MersenneTwister(5) - D, N = 4, 50 - - X = randn(rng, D, N) - Ξ = randn(rng, D, N) - # Force all rejections: u very close to 1 so log(u) ≈ 0 > any logα - u = ones(N) .* (1 - 1e-15) - - X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) - - # With u ≈ 1, log(u) ≈ 0 which is > most logα; almost all should be rejected - # (standard normal is well-behaved so this should reject everything or near it) - n_accepted = sum(accepted) - # At this u value essentially all should reject - @test n_accepted < N -end - -@testset "mala_step_batched DimensionMismatch errors" begin - X = randn(3, 5) - Ξ_bad = randn(3, 4) # wrong N - u_good = rand(5) - u_bad = rand(4) # wrong N - - @test_throws DimensionMismatch MALA.mala_step_batched( - logp_batch, gradlogp_batch, X, 0.1, Ξ_bad, u_good - ) - @test_throws DimensionMismatch MALA.mala_step_batched( - logp_batch, gradlogp_batch, X, 0.1, randn(3, 5), u_bad - ) -end - -@testset "mala_step_batched stationary distribution (many chains)" begin - rng = MersenneTwister(2025) - D, N, T = 3, 200, 2_000 - - X = randn(rng, D, N) - - for _ in 1:T - Ξ = randn(rng, D, N) - u = rand(rng, N) - X, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1, Ξ, u) - end - - # Pool all chains after burn-in (already burned in by running T steps) - mu = vec(mean(X; dims=2)) - vars = vec(var(X; dims=2)) - - @test maximum(abs.(mu)) < 0.15 - @test maximum(abs.(vars .- 1.0)) < 0.20 -end diff --git a/test/test-GPU-DEER.jl b/test/test-GPU-DEER.jl new file mode 100644 index 0000000..8501b41 --- /dev/null +++ b/test/test-GPU-DEER.jl @@ -0,0 +1,111 @@ +using Test +using Random +using LinearAlgebra + +using ParallelMCMC +import 
ParallelMCMC: DEER +import ADTypes + +cuda_ok = try + using CUDA + CUDA.functional() +catch + false +end + +if !cuda_ok + @warn "CUDA not available or not functional — skipping GPU DEER tests" +else + +# logp / gradlogp that work on any AbstractVector (CPU or GPU) +_logp_gpu(x) = -0.5f0 * dot(x, x) +_gradlogp_gpu(x) = -x + +@testset "solve_affine_scan_diag CPU vs GPU" begin + rng = MersenneTwister(1) + D, T = 4, 16 + + A_cpu = randn(rng, Float32, D, T) + B_cpu = randn(rng, Float32, D, T) + s0_cpu = randn(rng, Float32, D) + + S_cpu = DEER.solve_affine_scan_diag(A_cpu, B_cpu, s0_cpu) + + A_gpu = CUDA.cu(A_cpu) + B_gpu = CUDA.cu(B_cpu) + s0_gpu = CUDA.cu(s0_cpu) + + S_gpu = DEER.solve_affine_scan_diag(A_gpu, B_gpu, s0_gpu) + + @test S_gpu isa CUDA.CuMatrix + @test size(S_gpu) == (D, T) + @test Array(S_gpu) ≈ S_cpu atol=1e-5 +end + +enzyme_ok = try + using Enzyme + true +catch + false +end + +if !enzyme_ok + @warn "Enzyme not available — skipping GPU AD DEER tests" +else + + +@testset "DEER.solve on GPU with AutoEnzyme" begin + rng = MersenneTwister(42) + D, T = 3, 8 + + # Build a minimal TapedRecursion directly (avoids interface overhead) + ε = 0.05f0 + tape = map(1:T) do _ + ξ_gpu = CUDA.cu(randn(rng, Float32, D)) + u = Float32(rand(rng)) + MALATapeElement(ξ_gpu, u) + end + + rec = DEER.TapedRecursion( + (x, te) -> MALA.mala_step_taped(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, te.u), + (x, te, a) -> MALA.mala_step_surrogate(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, a), + tape; + consts = (x, te) -> (MALA.mala_accept_indicator(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, te.u),), + const_example = (0.0f0,), + backend = ADTypes.AutoEnzyme(), + ) + + s0_gpu = CUDA.cu(randn(rng, Float32, D)) + S, info = DEER.solve(rec, s0_gpu; maxiter=20, return_info=true) + + @test S isa CUDA.CuMatrix + @test size(S) == (D, T) + @test all(isfinite, Array(S)) +end + +@testset "DEERSampler interface on GPU" begin + rng = MersenneTwister(7) + D = 3 + + model = DensityModel(_logp_gpu, _gradlogp_gpu, D) + 
sampler = DEERSampler(0.05f0; T=8, maxiter=20, backend=ADTypes.AutoEnzyme()) + + x0_gpu = CUDA.cu(randn(rng, Float32, D)) + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0_gpu) + + @test trans isa DEERTransition + @test state isa DEERState + @test trans.x isa CUDA.CuVector + @test length(trans.x) == D + @test isfinite(Float32(trans.logp)) + @test size(state.trajectory) == (D, 8) + @test state.trajectory isa CUDA.CuMatrix + + # Second step consumes from the cached trajectory + trans2, state2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state2.t == 2 + @test trans2.x isa CUDA.CuVector +end + +end # enzyme_ok +end # cuda_ok From e8c4996d5732e5f3ed756df4e14ed72feaabffa8 Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Mon, 23 Mar 2026 12:59:10 -0400 Subject: [PATCH 5/8] get GPU tests passing --- test/LocalPreferences.toml | 3 ++ test/Project.toml | 4 ++ test/test-GPU-MALA.jl | 97 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 test/LocalPreferences.toml create mode 100644 test/test-GPU-MALA.jl diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml new file mode 100644 index 0000000..312f730 --- /dev/null +++ b/test/LocalPreferences.toml @@ -0,0 +1,3 @@ +[CUDA_Runtime_jll] +local = "true" +version = "12.8" diff --git a/test/Project.toml b/test/Project.toml index 693918c..e4bee46 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" @@ -8,3 +9,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[extras] +CUDA_Runtime_jll = 
"76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/test/test-GPU-MALA.jl b/test/test-GPU-MALA.jl new file mode 100644 index 0000000..d81ef67 --- /dev/null +++ b/test/test-GPU-MALA.jl @@ -0,0 +1,97 @@ +using Test +using Random +using LinearAlgebra +using Statistics + +using ParallelMCMC +const MALA = ParallelMCMC.MALA + +import CUDA + +# Check if a real GPU is accessible by attempting a small allocation. +# CUDA.functional() only checks that the library loads, not that a device exists. +const CUDA_AVAILABLE = try + CUDA.CuArray([1f0]) + true +catch + false +end + +if !CUDA_AVAILABLE + @info "No CUDA GPU detected — skipping GPU tests." +else + +# Standard normal target: logp(X) = -0.5 ||X||², grad = -X +# Broadcasting and sum work on CuArrays, so these are GPU-compatible. +logp_batch(X) = vec(-0.5f0 .* sum(abs2, X; dims=1)) +gradlogp_batch(X) = -X + +@testset "GPU inputs stay on device" begin + D, N = 4, 32 + + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + @test X_next isa CUDA.CuArray + @test accepted isa CUDA.CuArray + @test size(X_next) == (D, N) + @test length(accepted) == N +end + +@testset "GPU and CPU produce identical results (same seed)" begin + D, N = 3, 16 + + # Generate random numbers on CPU, copy to GPU. 
+ rng = MersenneTwister(42) + X_cpu = randn(rng, Float32, D, N) + Ξ_cpu = randn(rng, Float32, D, N) + u_cpu = rand(rng, Float32, N) + + X_gpu = CUDA.CuArray(X_cpu) + Ξ_gpu = CUDA.CuArray(Ξ_cpu) + u_gpu = CUDA.CuArray(u_cpu) + + X_next_cpu, acc_cpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_cpu, 0.1f0, Ξ_cpu, u_cpu) + X_next_gpu, acc_gpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_gpu, 0.1f0, Ξ_gpu, u_gpu) + + @test Array(X_next_gpu) ≈ X_next_cpu atol=1f-5 + @test Array(acc_gpu) == acc_cpu +end + +@testset "GPU stationary distribution (standard normal)" begin + D, N, T = 3, 512, 1_000 + + X = CUDA.randn(Float32, D, N) + + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + end + + X_cpu = Array(X) + mu = vec(mean(X_cpu; dims=2)) + vars = vec(var(X_cpu; dims=2)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1f0)) < 0.25 +end + +@testset "GPU DimensionMismatch errors" begin + X = CUDA.randn(Float32, 3, 5) + Ξ_bad = CUDA.randn(Float32, 3, 4) # wrong N + u_ok = CUDA.rand(Float32, 5) + u_bad = CUDA.rand(Float32, 4) # wrong N + + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1f0, Ξ_bad, u_ok, + ) + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1f0, CUDA.randn(Float32, 3, 5), u_bad, + ) +end + +end # if CUDA_AVAILABLE From 6141d2a5694c358735779dc29b2866a7bcb350a4 Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Mon, 23 Mar 2026 14:36:08 -0400 Subject: [PATCH 6/8] ADD MORE GPU TESTS --- test/test-GPU-MALA.jl | 166 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 155 insertions(+), 11 deletions(-) diff --git a/test/test-GPU-MALA.jl b/test/test-GPU-MALA.jl index d81ef67..b21fc4e 100644 --- a/test/test-GPU-MALA.jl +++ b/test/test-GPU-MALA.jl @@ -21,11 +21,23 @@ if !CUDA_AVAILABLE @info "No CUDA GPU detected — skipping GPU tests." 
else -# Standard normal target: logp(X) = -0.5 ||X||², grad = -X -# Broadcasting and sum work on CuArrays, so these are GPU-compatible. +# Standard normal: logp = -0.5 ||x||², grad = -x logp_batch(X) = vec(-0.5f0 .* sum(abs2, X; dims=1)) gradlogp_batch(X) = -X +# Scaled normal: dimension i has variance σᵢ² = i (so std = sqrt(i)) +# logp = -0.5 sum_i x_i²/i, grad_i = -x_i/i +function logp_scaled(X) + D = size(X, 1) + scales = CUDA.CuArray(Float32.(1:D)) # D-vector on GPU + return vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1)) +end +function gradlogp_scaled(X) + D = size(X, 1) + scales = CUDA.CuArray(Float32.(1:D)) + return -X ./ scales +end + @testset "GPU inputs stay on device" begin D, N = 4, 32 @@ -37,14 +49,24 @@ gradlogp_batch(X) = -X @test X_next isa CUDA.CuArray @test accepted isa CUDA.CuArray - @test size(X_next) == (D, N) + @test size(X_next) == (D, N) @test length(accepted) == N end +@testset "GPU eltype preserved (Float32 in → Float32 out)" begin + D, N = 5, 20 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + @test eltype(X_next) == Float32 +end + @testset "GPU and CPU produce identical results (same seed)" begin D, N = 3, 16 - # Generate random numbers on CPU, copy to GPU. 
rng = MersenneTwister(42) X_cpu = randn(rng, Float32, D, N) Ξ_cpu = randn(rng, Float32, D, N) @@ -61,11 +83,85 @@ end @test Array(acc_gpu) == acc_cpu end +@testset "GPU single chain (N=1) matches scalar mala_step_full" begin + D = 4 + rng = MersenneTwister(7) + x = randn(rng, Float32, D) + ξ = randn(rng, Float32, D) + u = rand(rng, Float32) + + # Scalar step on CPU + logp_scalar = x -> -0.5f0 * dot(x, x) + grad_scalar = x -> -x + x_scalar, acc_scalar = MALA.mala_step_full(logp_scalar, grad_scalar, x, 0.1f0, ξ, u) + + # Batched N=1 on GPU + X_gpu = CUDA.CuArray(reshape(copy(x), D, 1)) + Ξ_gpu = CUDA.CuArray(reshape(copy(ξ), D, 1)) + u_gpu = CUDA.CuArray([u]) + + X_next_gpu, acc_gpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_gpu, 0.1f0, Ξ_gpu, u_gpu) + + @test Array(vec(X_next_gpu)) ≈ x_scalar atol=1f-5 + @test Bool(Array(acc_gpu)[1]) == acc_scalar +end + +@testset "GPU rejected chains are unchanged" begin + D, N = 4, 64 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + # u very close to 1 ⟹ log(u) ≈ 0, forces rejection whenever logα ≤ 0. + # Some chains may still be accepted (logα > 0 when proposal lands at higher density), + # but for every rejected chain X_next must exactly equal X. + u = CUDA.fill(1f0 - 1f-6, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + acc_cpu = Array(accepted) + X_cpu = Array(X) + Xn_cpu = Array(X_next) + rejected = .!acc_cpu + + # Rejected chains must be exactly unchanged. + @test Xn_cpu[:, rejected] == X_cpu[:, rejected] + # With u ≈ 1 most should be rejected (not a strict all-zero check). 
+ @test sum(acc_cpu) < N +end + +@testset "GPU force all acceptances (u ≈ 0, tiny ε)" begin + D, N = 4, 64 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + # u very close to 0 ⟹ log(u) → -∞ < any logα ⟹ accept + u = CUDA.fill(1f-7, N) + + _, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.001f0, Ξ, u) + + @test sum(Array(accepted)) == N +end + +@testset "GPU acceptance rate in reasonable range" begin + # With ε=0.1 and standard normal, empirical acceptance rate should be + # well above 0 and below 1. + D, N, T = 5, 512, 200 + + n_accepted = 0 + X = CUDA.randn(Float32, D, N) + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, acc = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + n_accepted += sum(Array(acc)) + end + + rate = n_accepted / (T * N) + @test 0.3 < rate < 0.99 +end + @testset "GPU stationary distribution (standard normal)" begin D, N, T = 3, 512, 1_000 X = CUDA.randn(Float32, D, N) - for _ in 1:T Ξ = CUDA.randn(Float32, D, N) u = CUDA.rand(Float32, N) @@ -73,18 +169,66 @@ end end X_cpu = Array(X) - mu = vec(mean(X_cpu; dims=2)) - vars = vec(var(X_cpu; dims=2)) + @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.15 + @test maximum(abs.(vec(var(X_cpu; dims=2)) .- 1f0)) < 0.25 +end + +@testset "GPU stationary distribution (scaled normal)" begin + # Dimension i has variance i; after burn-in each row should have var ≈ i + D, N, T = 5, 512, 2_000 + + X = CUDA.randn(Float32, D, N) + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, _ = MALA.mala_step_batched(logp_scaled, gradlogp_scaled, X, 0.05f0, Ξ, u) + end + + X_cpu = Array(X) + vars = vec(var(X_cpu; dims=2)) # should ≈ [1, 2, 3, 4, 5] + target = Float32.(1:D) - @test maximum(abs.(mu)) < 0.15 - @test maximum(abs.(vars .- 1f0)) < 0.25 + @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.3 + @test maximum(abs.(vars .- target) ./ target) < 0.3 # within 30% relative +end + +@testset "GPU 
logq_mala_batched matches CPU" begin + D, N = 4, 16 + rng = MersenneTwister(3) + Y_cpu = randn(rng, Float32, D, N) + X_cpu = randn(rng, Float32, D, N) + G_cpu = randn(rng, Float32, D, N) + + lq_cpu = MALA.logq_mala_batched(Y_cpu, X_cpu, G_cpu, 0.1f0) + + Y_gpu = CUDA.CuArray(Y_cpu) + X_gpu = CUDA.CuArray(X_cpu) + G_gpu = CUDA.CuArray(G_cpu) + + lq_gpu = MALA.logq_mala_batched(Y_gpu, X_gpu, G_gpu, 0.1f0) + + @test Array(lq_gpu) ≈ lq_cpu atol=1f-4 +end + +@testset "GPU large-scale (D=128, N=1024)" begin + # Smoke test: just verify it runs and returns correct shapes without OOM. + D, N = 128, 1024 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.05f0, Ξ, u) + + @test size(X_next) == (D, N) + @test length(accepted) == N + @test eltype(X_next) == Float32 end @testset "GPU DimensionMismatch errors" begin X = CUDA.randn(Float32, 3, 5) - Ξ_bad = CUDA.randn(Float32, 3, 4) # wrong N + Ξ_bad = CUDA.randn(Float32, 3, 4) u_ok = CUDA.rand(Float32, 5) - u_bad = CUDA.rand(Float32, 4) # wrong N + u_bad = CUDA.rand(Float32, 4) @test_throws DimensionMismatch MALA.mala_step_batched( logp_batch, gradlogp_batch, X, 0.1f0, Ξ_bad, u_ok, From a2738a021a8102eccc21051cbfc7e9f8e76dd77e Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Mon, 23 Mar 2026 19:42:29 -0400 Subject: [PATCH 7/8] type preservation fixes --- .gitignore | 1 + src/MALA/MALA.jl | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 4d3d00e..2b0c91d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ coverage docs/build/ env node_modules +LocalPreferences.toml diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 8f1e936..8b7b947 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -15,7 +15,7 @@ function _quad_Minv(r, cholM::Cholesky) return dot(w, w) end -_logdet_M(::Nothing) = 0.0 +_logdet_M(::Nothing) = false # Bool promotes to any 
numeric type without widening _logdet_M(cholM::Cholesky) = logdet(cholM) """ @@ -30,10 +30,11 @@ function logq_mala( y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real; cholM=nothing, ) + T = typeof(ϵ) μ = x .+ ϵ .* _apply_M(gradlogp_x, cholM) d = length(x) r = y .- μ - return -0.5 * _quad_Minv(r, cholM) / (2ϵ) - (d / 2) * log(4π * ϵ) - 0.5 * _logdet_M(cholM) + return -T(0.5) * _quad_Minv(r, cholM) / (2ϵ) - (T(d) / 2) * log(T(4π) * ϵ) - T(0.5) * _logdet_M(cholM) end """ @@ -164,7 +165,7 @@ log acceptance ratio `logα = log p(y) + log q(x|y) - log p(x) - log q(y|x)`. The returned `logα` is the un-clamped value; the actual acceptance probability is `min(1, exp(logα))`. This is needed by adaptive step-size schemes (dual averaging). -Returns `(x_next, accepted::Bool, logα::Float64)`. +Returns `(x_next, accepted::Bool, logα)`. """ function mala_step_with_logα( logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; @@ -186,7 +187,7 @@ function mala_step_with_logα( accepted = log(u) < logα x_next = accepted ? y : x - return x_next, accepted, Float64(logα) + return x_next, accepted, logα end """ @@ -238,12 +239,13 @@ function logq_mala_batched( ε::Real; cholM=nothing, ) + T = typeof(ε) D = size(X, 1) μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM) R = Y .- μ q = _quad_Minv_batched(R, cholM) ldet = _logdet_M(cholM) - return @. -0.5 * q / (2ε) - (D / 2) * log(4π * ε) - 0.5 * ldet + return @. 
-T(0.5) * q / (2ε) - (T(D) / 2) * log(T(4π) * ε) - T(0.5) * ldet end """ From 19d39613a9a377eb41f04e17c77c090df09376cd Mon Sep 17 00:00:00 2001 From: Ryan Senne Date: Fri, 27 Mar 2026 10:05:21 -0400 Subject: [PATCH 8/8] remove LocalPreferences.toml --- test/LocalPreferences.toml | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 test/LocalPreferences.toml diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml deleted file mode 100644 index 312f730..0000000 --- a/test/LocalPreferences.toml +++ /dev/null @@ -1,3 +0,0 @@ -[CUDA_Runtime_jll] -local = "true" -version = "12.8"