diff --git a/.gitignore b/.gitignore index 4d3d00e..2b0c91d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ coverage docs/build/ env node_modules +LocalPreferences.toml diff --git a/Project.toml b/Project.toml index dd88cdc..0c7b24e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,22 +4,31 @@ version = "0.1.0" authors = ["Ryan Senne "] [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" + +[extensions] +LogDensityProblemsExt = "LogDensityProblems" + [compat] +ADTypes = "1.21.0" AbstractMCMC = "5.10.0" +CUDA = "5.11.0" DifferentiationInterface = "0.7.13" -Distributions = "0.25.122" +Enzyme = "0.13.131" LinearAlgebra = "1.12.0" -LogExpFunctions = "0.3.29" +LogDensityProblems = "2" MCMCChains = "7.7.0" Mooncake = "0.4.192" Random = "1.11.0" diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl new file mode 100644 index 0000000..84effbc --- /dev/null +++ b/ext/LogDensityProblemsExt.jl @@ -0,0 +1,54 @@ +module LogDensityProblemsExt + +using ParallelMCMC +import LogDensityProblems + +""" + DensityModel(ld) + +Construct a `DensityModel` from any object implementing the +[LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface. + +`ld` must support: +- `LogDensityProblems.capabilities(ld)` returning at least + `LogDensityProblems.LogDensityOrder{1}` (i.e. 
gradient available). +- `LogDensityProblems.dimension(ld)` → `Int` +- `LogDensityProblems.logdensity_and_gradient(ld, x)` → `(logp, grad)` + +# Turing.jl / DynamicPPL example +```julia +using Turing, LogDensityProblems, LogDensityProblemsAD, Mooncake, ParallelMCMC, MCMCChains + +@model function mymodel(y) + μ ~ Normal(0, 1) + y ~ Normal(μ, 0.5) +end + +obs = 1.5 +ld = DynamicPPL.LogDensityFunction(mymodel(obs)) +ldg = LogDensityProblemsAD.ADgradient(Mooncake.Extras.MooncakeAD(), ld) + +model = DensityModel(ldg) +chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; + chain_type=MCMCChains.Chains, progress=true) +``` +""" +function ParallelMCMC.DensityModel(ld) + caps = LogDensityProblems.capabilities(ld) + caps isa LogDensityProblems.LogDensityOrder{0} && + error("LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). " * + "Wrap it with LogDensityProblemsAD.ADgradient first.") + + dim = LogDensityProblems.dimension(ld) + + logp(x) = LogDensityProblems.logdensity(ld, x) + + function gradlogp(x) + _, g = LogDensityProblems.logdensity_and_gradient(ld, x) + return g + end + + return ParallelMCMC.DensityModel(logp, gradlogp, dim) +end + +end # module diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index bc31819..bc09a89 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -3,36 +3,43 @@ module DEER using Base.Threads: @threads, threadid using LinearAlgebra using DifferentiationInterface +using ADTypes: ADTypes, AbstractADType import Mooncake: Mooncake using Random const DI = DifferentiationInterface -backend = DI.AutoMooncake(; config=nothing) +const DEFAULT_BACKEND = DI.AutoMooncake(; config=nothing) """ Deterministic recursion driven by a pre-generated tape. - `step_fwd(x, tape_t)`: exact forward transition (may include MH accept/reject). - `step_lin(x, tape_t, c...)`: surrogate used only for Jacobians. -- `consts(x, tape_t) -> Tuple`: returns constants `c...` passed into `step_lin` as DI.Constant. 
+- `consts(x, tape_t) -> Tuple`: returns constants `c...` passed into `step_lin`. - `const_example`: example tuple of constants, used in `prepare`. +- `backend`: AD backend (any `ADTypes.AbstractADType`); defaults to `AutoMooncake`. + Pass `AutoEnzyme()` or another GPU-compatible backend for GPU execution. """ -struct TapedRecursion{Ff,Fl,Fc,Tt,Ce} +struct TapedRecursion{Ff,Fl,Fc,Tt,Ce,AD<:AbstractADType} step_fwd::Ff step_lin::Fl consts::Fc tape::Vector{Tt} const_example::Ce + backend::AD end -"Backward-compatible constructor: uses the same step for forward + Jacobian, and no extra constants." -TapedRecursion(step, tape::Vector) = TapedRecursion(step, step, (_x, _tt)->(), tape, ()) +"Backward-compatible constructor: uses the same step for forward + Jacobian, no constants." +function TapedRecursion(step, tape::Vector) + return TapedRecursion(step, step, (_x, _tt) -> (), tape, (), DEFAULT_BACKEND) +end "Main constructor." function TapedRecursion( - step_fwd, step_lin, tape::Vector; consts=(_x, _tt)->(), const_example=() + step_fwd, step_lin, tape::Vector; + consts=(_x, _tt) -> (), const_example=(), backend::AbstractADType=DEFAULT_BACKEND, ) - return TapedRecursion(step_fwd, step_lin, consts, tape, const_example) + return TapedRecursion(step_fwd, step_lin, consts, tape, const_example, backend) end # Stable callable for DI (Jacobian always taken w.r.t. step_lin) @@ -43,64 +50,58 @@ end "Prepare Jacobian for the surrogate step (reusable across t as long as tape element type is stable)." function prepare(rec::TapedRecursion, x0::AbstractVector) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.const_example return DI.prepare_jacobian( - f, backend, x0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... + f, rec.backend, x0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... ) end "Prepare pushforward (JVP) for the surrogate step_lin." 
function prepare_pushforward(rec::TapedRecursion, x0::AbstractVector) - f = StepWithTape(rec) - cs = rec.const_example - # tx is a tuple of tangents; we supply an example tangent + f = StepWithTape(rec) + cs = rec.const_example tx0 = (zero(x0),) return DI.prepare_pushforward( - f, backend, x0, tx0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... + f, rec.backend, x0, tx0, DI.Constant(rec.tape[1]), (DI.Constant(c) for c in cs)... ) end "Full Jacobian of surrogate step_lin(x, tape_t, consts...) w.r.t. x." function jac_full(rec::TapedRecursion, prep, x::AbstractVector, t::Int) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.consts(x, rec.tape[t]) return DI.jacobian( - f, prep, backend, x, DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)... + f, prep, rec.backend, x, DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)... ) end -"Diagonal of the surrogate Jacobian (debug/correctness mode; not scalable if it computes full J)." +"Diagonal of the surrogate Jacobian via full Jacobian computation." function jac_diag(rec::TapedRecursion, prep, x::AbstractVector, t::Int) return diag(jac_full(rec, prep, x, t)) end -@inline function _rademacher!(z::AbstractVector, rng::AbstractRNG) - for i in eachindex(z) - z[i] = rand(rng, Bool) ? 1.0 : -1.0 +@inline function _rademacher!(z::AbstractVector{T}, rng::AbstractRNG) where {T} + D = length(z) + bits = rand(rng, Bool, D) + vals = Vector{T}(undef, D) + for i in 1:D + vals[i] = bits[i] ? 
one(T) : -one(T) end + copyto!(z, vals) # host-to-device when z is a CuVector; plain copy otherwise return z end function _jvp_step_lin( rec::TapedRecursion, prep_pf, x::AbstractVector, t::Int, v::AbstractVector ) - f = StepWithTape(rec) + f = StepWithTape(rec) cs = rec.consts(x, rec.tape[t]) - - # pushforward expects tx as a tuple of tangents tx = (v,) - res = DI.pushforward( - f, - prep_pf, - backend, - x, - tx, - DI.Constant(rec.tape[t]), - (DI.Constant(c) for c in cs)..., + f, prep_pf, rec.backend, x, tx, + DI.Constant(rec.tape[t]), (DI.Constant(c) for c in cs)..., ) - return res isa Tuple ? res[end] : res end @@ -112,7 +113,7 @@ diag(J) ≈ (1/K) * Σ_k z^(k) ⊙ (J z^(k)), z_i ∈ {±1} Keywords: - probes: number of random probe vectors K (typical 1–4) - rng: RNG for probes -- zbuf, jbuf: optional preallocated buffers (length D) to avoid allocations +- zbuf: optional preallocated buffer (length D) to avoid allocations """ function jac_diag_stoch( rec::TapedRecursion, @@ -123,87 +124,88 @@ function jac_diag_stoch( rng::AbstractRNG=Random.default_rng(), zbuf::Union{Nothing,AbstractVector}=nothing, ) - D = length(x) + D = length(x) + FT = float(eltype(x)) probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - z = zbuf === nothing ? Vector{Float64}(undef, D) : zbuf + z = zbuf === nothing ? Vector{FT}(undef, D) : zbuf length(z) == D || throw(ArgumentError("zbuf must have length D")) - d = zeros(Float64, D) - - for k in 1:probes + # d matches the array type of z (CuVector on GPU, Vector on CPU). + d = fill!(similar(z, D), zero(FT)) + for _ in 1:probes _rademacher!(z, rng) - jv = _jvp_step_lin(rec, prep_pf, x, t, z) # J*z - for i in 1:D - d[i] += z[i] * jv[i] - end - end - - invK = 1.0 / probes - for i in 1:D - d[i] *= invK + jv = _jvp_step_lin(rec, prep_pf, x, t, z) + @. d += z * jv end + d .*= one(FT) / probes return d end """ -Solve s_t = a_t .* s_{t-1} + b_t for t=1..T using an associative inclusive scan. 
+ solve_affine_scan_diag(A, B, s0) + +Solve the diagonal affine recurrence `s_t = A[:,t] .* s_{t-1} + B[:,t]` for t=1..T +via an associative inclusive parallel-prefix (Hillis-Steele) scan. + +**Complexity**: O(log T) sequential sweep levels; each level is a single broadcast +over all T columns — no loops or `@threads`. The implementation is +**array-type-agnostic**: it works identically on CPU `Matrix`, GPU `CuMatrix`, or +any other `AbstractMatrix`. Inputs: -- A :: D×T matrix, A[:,t] = a_t -- B :: D×T matrix, B[:,t] = b_t -- s0 :: length-D vector +- `A` :: D×T — per-step multiplicative coefficients +- `B` :: D×T — per-step additive offsets +- `s0` :: length-D initial state Returns: -- S :: D×T matrix, S[:,t] = s_t +- `S` :: D×T (same type as A) where `S[:,t] = s_t` """ function solve_affine_scan_diag(A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector) D, T = size(A) size(B) == (D, T) || throw(ArgumentError("B must have the same size as A")) - length(s0) == D || throw(ArgumentError("s0 length must match size(A,1)")) + length(s0) == D || throw(ArgumentError("s0 length must match size(A,1)")) - α = Matrix{Float64}(A) # working prefixes - β = Matrix{Float64}(B) + # Work on copies; swap buffers each level to avoid allocating inside the loop. + α = copy(A) + β = copy(B) αnew = similar(α) βnew = similar(β) offset = 1 while offset < T - @threads for t in 1:T - if t > offset - for i in 1:D - ai = α[i, t] - αnew[i, t] = ai * α[i, t - offset] - βnew[i, t] = ai * β[i, t - offset] + β[i, t] - end - else - for i in 1:D - αnew[i, t] = α[i, t] - βnew[i, t] = β[i, t] - end - end - end + # Columns 1:offset are unchanged at this level. + αnew[:, 1:offset] .= α[:, 1:offset] + βnew[:, 1:offset] .= β[:, 1:offset] + + # Columns offset+1:T combine with their left neighbour at distance `offset`. 
+ # αnew[:,t] = α[:,t] * α[:,t-offset] + # βnew[:,t] = α[:,t] * β[:,t-offset] + β[:,t] + αnew[:, offset+1:T] .= α[:, offset+1:T] .* α[:, 1:T-offset] + βnew[:, offset+1:T] .= α[:, offset+1:T] .* β[:, 1:T-offset] .+ β[:, offset+1:T] + α, αnew = αnew, α β, βnew = βnew, β offset <<= 1 end - S = Matrix{Float64}(undef, D, T) - @threads for t in 1:T - for i in 1:D - S[i, t] = α[i, t] * s0[i] + β[i, t] - end - end + # Apply initial condition: S[:,t] = α[:,t] * s0 + β[:,t] + S = similar(A, D, T) + S .= α .* reshape(s0, D, 1) .+ β return S end """ -Compute one DEER update given a trajectory guess S (D×T). -Returns S_new (D×T). +Compute one DEER update given a trajectory guess `S` (D×T). +Returns `S_new` (D×T). + +The per-timestep Jacobian computation uses the AD backend stored in `rec.backend`. +For GPU execution, construct the `TapedRecursion` with a GPU-compatible backend +(e.g. `AutoEnzyme()`). Keywords: -- jacobian: :diag or :full -- damping: in (0,1] +- `jacobian`: `:diag`, `:stoch_diag`, or `:full` +- `damping`: in (0,1] """ function deer_update( rec::TapedRecursion, @@ -215,28 +217,24 @@ function deer_update( probes::Int=1, rng::AbstractRNG=Random.default_rng(), ) - s0 = vec(s0_in) + s0 = vec(s0_in) D, T = size(S) - length(s0) == D || throw(ArgumentError("S must be D×T with D=length(s0)")) - T == length(rec.tape) || throw(ArgumentError("size(S,2) must match length(rec.tape)")) + length(s0) == D || throw(ArgumentError("S must be D×T with D=length(s0)")) + T == length(rec.tape) || throw(ArgumentError("size(S,2) must match length(rec.tape)")) damping > 0 && damping ≤ 1 || throw(ArgumentError("damping must be in (0,1]")) - probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - - if jacobian === :diag || jacobian === :stoch_diag - A = Matrix{Float64}(undef, D, T) - B = Matrix{Float64}(undef, D, T) - - nt = Base.Threads.maxthreadid() + probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - # scratch vector per thread for xbar - xbufs = [zeros(Float64, D) for _ in 
1:nt] + FT = float(eltype(s0)) - # scratch per thread for stochastic diag probes - zbufs = jacobian === :stoch_diag ? [zeros(Float64, D) for _ in 1:nt] : nothing - - # per-thread RNGs for reproducibility + thread safety - rngs = if jacobian === :stoch_diag - # Seed thread RNGs from the provided rng on the main thread (deterministic given rng). + if jacobian === :diag || jacobian === :stoch_diag + # A and B match the array type of S (CuMatrix on GPU, Matrix on CPU). + A = similar(S) + B = similar(S) + + nt = Base.Threads.maxthreadid() + # zbufs for stoch_diag: match array type of s0 so GPU backends work correctly. + zbufs = jacobian === :stoch_diag ? [similar(s0, D) for _ in 1:nt] : nothing + rngs = if jacobian === :stoch_diag seeds = rand(rng, UInt, nt) [MersenneTwister(seeds[i]) for i in 1:nt] else @@ -244,83 +242,59 @@ function deer_update( end @threads for t in 1:T - tid = threadid() - - xbar = if t == 1 - s0 - else - xb = xbufs[tid] - for i in 1:D - xb[i] = S[i, t - 1] - end - xb - end + tid = threadid() + # S[:, t-1] returns a concrete column copy — Vector on CPU, CuVector on GPU. + xbar = t == 1 ? s0 : S[:, t - 1] jt = if jacobian === :diag jac_diag(rec, prep, xbar, t) else jac_diag_stoch( - rec, prep, xbar, t; probes=probes, rng=rngs[tid], zbuf=zbufs[tid] + rec, prep, xbar, t; probes=probes, rng=rngs[tid], zbuf=zbufs[tid], ) end ft = rec.step_fwd(xbar, rec.tape[t]) - - for i in 1:D - A[i, t] = jt[i] - B[i, t] = ft[i] - jt[i] * xbar[i] - end + # Broadcast column assignment: works on both CPU Matrix and GPU CuMatrix. + view(A, :, t) .= jt + view(B, :, t) .= ft .- jt .* xbar end S_new = solve_affine_scan_diag(A, B, s0) - if damping != 1.0 - @threads for t in 1:T - for i in 1:D - S_new[i, t] = (1 - damping) * S[i, t] + damping * S_new[i, t] - end - end + @. 
S_new = (1 - damping) * S + damping * S_new end - return S_new elseif jacobian === :full - A = Vector{Matrix{Float64}}(undef, T) - b = Matrix{Float64}(undef, D, T) - - xbuf = zeros(Float64, D) # sequential scratch + A_mats = Vector{Matrix{FT}}(undef, T) + b = Matrix{FT}(undef, D, T) + xbuf = zeros(FT, D) for t in 1:T xbar = if t == 1 s0 else - for i in 1:D - xbuf[i] = S[i, t - 1] - end + for i in 1:D; xbuf[i] = S[i, t - 1]; end xbuf end - Jt = jac_full(rec, prep, xbar, t) - A[t] = Jt - - ft = rec.step_fwd(xbar, rec.tape[t]) - tmp = Jt * xbar - for i in 1:D - b[i, t] = ft[i] - tmp[i] - end + Jt = jac_full(rec, prep, xbar, t) + A_mats[t] = Jt + ft = rec.step_fwd(xbar, rec.tape[t]) + tmp = Jt * xbar + for i in 1:D; b[i, t] = ft[i] - tmp[i]; end end - S_new = Matrix{Float64}(undef, D, T) + S_new = Matrix{FT}(undef, D, T) s_prev = copy(s0) for t in 1:T - s_prev = A[t] * s_prev .+ view(b, :, t) + s_prev = A_mats[t] * s_prev .+ view(b, :, t) S_new[:, t] .= s_prev end - if damping != 1.0 S_new .= (1 - damping) .* S .+ damping .* S_new end - return S_new else throw(ArgumentError("jacobian must be :diag, :stoch_diag, or :full")) @@ -328,7 +302,7 @@ function deer_update( end @inline function _maxabs(x) - m = 0.0 + m = zero(real(eltype(x))) for i in eachindex(x) v = abs(x[i]) m = ifelse(v > m, v, m) @@ -337,7 +311,7 @@ end end @inline function _maxabsdiff(x, y) - m = 0.0 + m = zero(real(promote_type(eltype(x), eltype(y)))) for i in eachindex(x, y) v = abs(x[i] - y[i]) m = ifelse(v > m, v, m) @@ -348,16 +322,25 @@ end """ Run DEER iterations until convergence. -Returns: -- S :: D×T matrix trajectory (S[:,t] is state at time t) +Returns the solved trajectory `S :: D×T` (same array type as the initial state, +so passing a `CuVector` for `s0_in` will yield a `CuMatrix`). + +# GPU use +1. Construct a `DensityModel` whose `logdensity` and `grad_logdensity` operate on + GPU arrays. +2. 
Pass `backend = AutoEnzyme()` (or another GPU-compatible `ADTypes` backend) to + `DEERSampler` so that DEER uses it when building the `TapedRecursion`. +3. Pass a GPU vector as `initial_params` (or use a `DEERSampler` with a float type + matching the GPU array element type). + +The `solve_affine_scan_diag` kernel runs as pure broadcasts and requires no +backend-specific GPU code; only the Jacobian computation via `DI` depends on the +chosen backend. Keywords: -- init: initial trajectory guess, D×T (default: repeat s0) -- tol_abs, tol_rel: stopping tolerances (∞-norm per time step) -- maxiter -- jacobian: :diag or :full -- damping -- return_info +- `init`: initial trajectory guess D×T (default: repeat s0 across columns) +- `tol_abs`, `tol_rel`: stopping tolerances (∞-norm per time step) +- `maxiter`, `jacobian`, `damping`, `probes`, `return_info` """ function solve( rec::TapedRecursion, @@ -373,73 +356,62 @@ function solve( return_info::Bool=false, ) s0 = vec(s0_in) - D = length(s0) - T = length(rec.tape) + D = length(s0) + T = length(rec.tape) maxiter ≥ 1 || throw(ArgumentError("maxiter must be ≥ 1")) tol_abs ≥ 0 || throw(ArgumentError("tol_abs must be ≥ 0")) tol_rel ≥ 0 || throw(ArgumentError("tol_rel must be ≥ 0")) damping > 0 && damping ≤ 1 || throw(ArgumentError("damping must be in (0,1]")) + # Initial trajectory guess: preserve array type of s0 S = if init === nothing - # repeat s0 across time - M = Matrix{Float64}(undef, D, T) - @threads for t in 1:T - M[:, t] .= s0 - end - M + S0 = similar(s0, D, T) + S0 .= reshape(s0, D, 1) + S0 else size(init) == (D, T) || throw(ArgumentError("init must be size (D,T)")) - Matrix{Float64}(init) + copy(init) end prep = if jacobian === :stoch_diag prepare_pushforward(rec, s0) - else + elseif jacobian === :full || jacobian === :diag prepare(rec, s0) + else + nothing end - converged = false + converged = false last_metric = Inf - iters = 0 + iters = 0 for iter in 1:maxiter iters = iter S_new = deer_update( - rec, s0, S, prep; 
jacobian=jacobian, damping=damping, probes=probes, rng=rng + rec, s0, S, prep; jacobian=jacobian, damping=damping, probes=probes, rng=rng, ) - metric = 0.0 - for t in 1:T - xnew = view(S_new, :, t) - xold = view(S, :, t) - Δ = _maxabsdiff(xnew, xold) - scale = tol_abs + tol_rel * _maxabs(xnew) - metric = max(metric, Δ / scale) - end + # Single pair of reductions over the full D×T matrix — one GPU kernel each. + # This replaces T per-column scalar loops and is efficient on both CPU and GPU. + Δ_max = maximum(abs.(S_new .- S)) + S_scale = tol_abs + tol_rel * maximum(abs.(S_new)) + metric = Δ_max / S_scale - S = S_new + S = S_new last_metric = metric - - if metric ≤ 1.0 - converged = true - break - end + metric ≤ 1 && (converged = true; break) end - if return_info - return S, - ( - converged=converged, - iters=iters, - metric=last_metric, - jacobian=jacobian, - damping=damping, - probes=probes, - ) - else - return S - end + return_info || return S + return S, ( + converged = converged, + iters = iters, + metric = last_metric, + jacobian = jacobian, + damping = damping, + probes = probes, + ) end end # module DEER diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 81d5141..1de5d9a 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -2,20 +2,39 @@ module MALA using Random, LinearAlgebra +# Preconditioner dispatch helpers. cholM is either nothing (identity) or a Cholesky factor. +_apply_M(g, ::Nothing) = g +_apply_M(g, cholM::Cholesky) = cholM.L * (cholM.L' * g) + +_apply_L(ξ, ::Nothing) = ξ +_apply_L(ξ, cholM::Cholesky) = cholM.L * ξ + +_quad_Minv(r, ::Nothing) = dot(r, r) +function _quad_Minv(r, cholM::Cholesky) + w = cholM.L \ r + return dot(w, w) +end + +_logdet_M(::Nothing) = false # Bool promotes to any numeric type without widening +_logdet_M(cholM::Cholesky) = logdet(cholM) + """ Compute the log of the MALA proposal density q(y | x). 
-We use the Gaussian: - y ~ Normal(x + ϵ∇logp(x), 2ϵ I) + +With mass matrix M (passed as `cholM = cholesky(M)`): + y ~ Normal(x + ϵ M ∇logp(x), 2ϵ M) + +With `cholM=nothing` (default), uses identity M = I. """ function logq_mala( - y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real + y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real; + cholM=nothing, ) - μ = x .+ ϵ .* gradlogp_x - # log N(y; μ, 2ϵ I) up to constant: - # -0.5 * ||y-μ||^2 / (2ϵ) - (d/2) log(4πϵ) + T = typeof(ϵ) + μ = x .+ ϵ .* _apply_M(gradlogp_x, cholM) d = length(x) r = y .- μ - return -0.5 * dot(r, r) / (2ϵ) - (d / 2) * log(4π * ϵ) + return -T(0.5) * _quad_Minv(r, cholM) / (2ϵ) - (T(d) / 2) * log(T(4π) * ϵ) - T(0.5) * _logdet_M(cholM) end """ @@ -28,32 +47,30 @@ Inputs: - ϵ: step size - ξ: N(0, I) noise vector (tape) - u: Uniform(0,1) scalar (tape) +- cholM: optional Cholesky factor of mass matrix M (default: identity) Returns: - x_next """ function mala_step_taped( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, ) - @assert length(x) == length(ξ) - @assert 0.0 < u < 1.0 + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) + 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) g_x = gradlogp(x) - # Proposal - y = x .+ ϵ .* g_x .+ sqrt(2ϵ) .* ξ + y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) - # Compute log acceptance ratio: - # log α = logp(y) + log q(x|y) - logp(x) - log q(y|x) logp_x = logp(x) logp_y = logp(y) g_y = gradlogp(y) - logq_y_given_x = logq_mala(y, x, g_x, ϵ) - logq_x_given_y = logq_mala(x, y, g_y, ϵ) + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) - # Accept/reject using tape u return (log(u) < logα) ? 
y : x end @@ -74,59 +91,218 @@ function run_mala_sequential_taped( x0::AbstractVector, ϵ::Real, ξs::Vector{<:AbstractVector}, - us::AbstractVector, + us::AbstractVector; + cholM=nothing, ) T = length(us) - @assert length(ξs) == T + length(ξs) == T || throw(DimensionMismatch("ξs and us must have the same length")) xs = Vector{typeof(x0)}(undef, T + 1) xs[1] = copy(x0) x = copy(x0) for t in 1:T - x = mala_step_taped(logp, gradlogp, x, ϵ, ξs[t], us[t]) + x = mala_step_taped(logp, gradlogp, x, ϵ, ξs[t], us[t]; cholM=cholM) xs[t + 1] = copy(x) end return xs end -"Compute the MALA proposal map y = x + ϵ∇logp(x) + sqrt(2ϵ) ξ." -function mala_proposal(logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector) - @assert length(x) == length(ξ) +"Compute the MALA proposal map y = x + ϵ M ∇logp(x) + √(2ϵ) L ξ." +function mala_proposal( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector; cholM=nothing, +) + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) g_x = gradlogp(x) - return x .+ ϵ .* g_x .+ sqrt(2ϵ) .* ξ + return x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) end "Compute log acceptance ratio logα(x→y) for MALA." -function mala_logα(logp, gradlogp, x::AbstractVector, y::AbstractVector, ϵ::Real) +function mala_logα( + logp, gradlogp, x::AbstractVector, y::AbstractVector, ϵ::Real; cholM=nothing, +) g_x = gradlogp(x) g_y = gradlogp(y) logp_x = logp(x) logp_y = logp(y) - logq_y_given_x = logq_mala(y, x, g_x, ϵ) - logq_x_given_y = logq_mala(x, y, g_y, ϵ) + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) return (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) end """ Primal accept indicator for a taped MALA step. -Returns Float64 in {0.0, 1.0} so it can be used as a constant gate. +Returns a float in {0, 1} matching the precision of `u`, for use as a constant +gate in the DEER surrogate step. 
""" function mala_accept_indicator( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, +) + y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) + logα = mala_logα(logp, gradlogp, x, y, ϵ; cholM=cholM) + FP = typeof(float(u)) + return (log(u) < logα) ? one(FP) : zero(FP) +end + +""" +One taped MALA step, returning both the next state and the accept flag. + +This is the efficient entry point: it evaluates `gradlogp` exactly twice (once at `x`, +once at the proposal `y`), compared to calling `mala_accept_indicator` + `mala_step_taped` +separately which evaluates it five times. + +Returns `(x_next, accepted::Bool)`. +""" +function mala_step_full( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, +) + x_next, accepted, _ = mala_step_with_logα(logp, gradlogp, x, ϵ, ξ, u; cholM=cholM) + return x_next, accepted +end + +""" +One taped MALA step, returning the next state, the accept flag, **and** the raw +log acceptance ratio `logα = log p(y) + log q(x|y) - log p(x) - log q(y|x)`. + +The returned `logα` is the un-clamped value; the actual acceptance probability is +`min(1, exp(logα))`. This is needed by adaptive step-size schemes (dual averaging). + +Returns `(x_next, accepted::Bool, logα)`. +""" +function mala_step_with_logα( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; + cholM=nothing, ) - y = mala_proposal(logp, gradlogp, x, ϵ, ξ) - logα = mala_logα(logp, gradlogp, x, y, ϵ) - return (log(u) < logα) ? 
1.0 : 0.0 + length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) + 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) + + g_x = gradlogp(x) + y = x .+ ϵ .* _apply_M(g_x, cholM) .+ sqrt(2ϵ) .* _apply_L(ξ, cholM) + + logp_x = logp(x) + logp_y = logp(y) + g_y = gradlogp(y) + + logq_y_given_x = logq_mala(y, x, g_x, ϵ; cholM=cholM) + logq_x_given_y = logq_mala(x, y, g_y, ϵ; cholM=cholM) + logα = (logp_y + logq_x_given_y) - (logp_x + logq_y_given_x) + + accepted = log(u) < logα + x_next = accepted ? y : x + return x_next, accepted, logα end """ Stop-gradient surrogate step used for Jacobians. -`a` (0.0 or 1.0) must be provided as a constant by the DEER machinery. +`a` (0 or 1) must be provided as a constant by the DEER machinery. """ function mala_step_surrogate( - logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, a::Real + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, a::Real; + cholM=nothing, ) - y = mala_proposal(logp, gradlogp, x, ϵ, ξ) + y = mala_proposal(logp, gradlogp, x, ϵ, ξ; cholM=cholM) return (a .* y) .+ ((1 - a) .* x) end +# Apply mass matrix to a D×N matrix of gradient columns (same math as scalar, +# matrix multiply broadcasts naturally). +_apply_M_batched(G::AbstractMatrix, ::Nothing) = G +_apply_M_batched(G::AbstractMatrix, cholM::Cholesky) = cholM.L * (cholM.L' * G) + +_apply_L_batched(Ξ::AbstractMatrix, ::Nothing) = Ξ +_apply_L_batched(Ξ::AbstractMatrix, cholM::Cholesky) = cholM.L * Ξ + +""" +Compute column-wise M⁻¹-norm squared: `[||R[:,n]||²_{M⁻¹}]_n`. +`R` is D×N; returns a length-N vector. +GPU-compatible (uses `sum(abs2, …; dims=1)` which works on CuArrays). +Note: `cholM` must be `nothing` when using GPU arrays; the triangular solve +`cholM.L \\ R` pulls device arrays to CPU. 
+""" +function _quad_Minv_batched(R::AbstractMatrix, ::Nothing) + return vec(sum(abs2, R; dims=1)) +end + +function _quad_Minv_batched(R::AbstractMatrix, cholM::Cholesky) + W = cholM.L \ R + return vec(sum(abs2, W; dims=1)) +end + +""" + logq_mala_batched(Y, X, gradlogp_X, ε; cholM=nothing) + +Compute `log q(Y[:,n] | X[:,n])` for all N chains simultaneously. +`Y`, `X`, `gradlogp_X` are D×N; returns a length-N vector. +""" +function logq_mala_batched( + Y::AbstractMatrix, + X::AbstractMatrix, + gradlogp_X::AbstractMatrix, + ε::Real; + cholM=nothing, +) + T = typeof(ε) + D = size(X, 1) + μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM) + R = Y .- μ + q = _quad_Minv_batched(R, cholM) + ldet = _logdet_M(cholM) + return @. -T(0.5) * q / (2ε) - (T(D) / 2) * log(T(4π) * ε) - T(0.5) * ldet +end + +""" + mala_step_batched(logp_batch, gradlogp_batch, X, ε, Ξ, u; cholM=nothing) + +Run one MALA step for N chains simultaneously. + +- `X` :: D×N — current states (one chain per column). +- `Ξ` :: D×N — N(0,I) noise. +- `u` :: length-N — Uniform(0,1) draws. +- `logp_batch(X)` → length-N log-densities. +- `gradlogp_batch(X)` → D×N gradient matrix. + +Returns `(X_next::AbstractMatrix, accepted::AbstractVector)`. + +**GPU use:** pass `CuArray` inputs and GPU-compatible `logp_batch`/`gradlogp_batch`. +Requires `cholM=nothing` for full on-device execution (Cholesky preconditioner involves +a CPU-side triangular solve). Use `eltype(X)` for `ε` to avoid float-type promotions +that would pull data off GPU. +""" +function mala_step_batched( + logp_batch, + gradlogp_batch, + X::AbstractMatrix, + ε::Real, + Ξ::AbstractMatrix, + u::AbstractVector; + cholM=nothing, +) + D, N = size(X) + size(Ξ) == (D, N) || throw(DimensionMismatch("X and Ξ must have the same size")) + length(u) == N || throw(DimensionMismatch("u must have length N = size(X,2)")) + + # Cast ε to element type of X to avoid float-promotion off GPU. 
+ ε_T = eltype(X)(ε) + + G_X = gradlogp_batch(X) # D×N + Y = X .+ ε_T .* _apply_M_batched(G_X, cholM) .+ + sqrt(2 * ε_T) .* _apply_L_batched(Ξ, cholM) # D×N + + lp_X = logp_batch(X) # N + lp_Y = logp_batch(Y) # N + G_Y = gradlogp_batch(Y) # D×N + + lq_YX = logq_mala_batched(Y, X, G_X, ε_T; cholM=cholM) # N + lq_XY = logq_mala_batched(X, Y, G_Y, ε_T; cholM=cholM) # N + + logα = @. (lp_Y + lq_XY) - (lp_X + lq_YX) # N + accepted = @. log(u) < logα # N Bool + + # Select: proposal if accepted, current if rejected. + # reshape to 1×N so it broadcasts against D×N. + mask = reshape(accepted, 1, N) + X_next = @. ifelse(mask, Y, X) # D×N + return X_next, vec(accepted) +end + end # module diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index 37900e5..d951d89 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -1,13 +1,21 @@ module ParallelMCMC -# imports -using Distributions -using LinearAlgebra -using LogExpFunctions +using AbstractMCMC +using CUDA +using Enzyme using MCMCChains +using LinearAlgebra +using Random +using Statistics -# inclusions include("MALA/MALA.jl") include("DEER/DEER.jl") +include("interface.jl") + +export DensityModel +export MALASampler, MALATransition, MALAState +export AdaptiveMALASampler, AdaptiveMALATransition, AdaptiveMALAState +export DEERSampler, DEERTransition, DEERState, MALATapeElement +export MALA, DEER end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..c6b7569 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,617 @@ +#= +AbstractMCMC interface for ParallelMCMC samplers. + +Defines model/sampler/state/transition types and implements +`AbstractMCMC.step` so that `sample(model, sampler, N)` works out of the box. +=# + +""" + DensityModel(logdensity, grad_logdensity, dim) + +Wraps a log-density function and its gradient for use with ParallelMCMC samplers. 
+ +- `logdensity(x::AbstractVector) -> Real` +- `grad_logdensity(x::AbstractVector) -> AbstractVector` +- `dim::Int` — dimensionality of the parameter space +""" +struct DensityModel{F,G} <: AbstractMCMC.AbstractModel + logdensity::F + grad_logdensity::G + dim::Int +end + +""" + MALASampler(epsilon; cholM=nothing) + +Metropolis-Adjusted Langevin Algorithm sampler with step size `epsilon`. + +Optionally pass `cholM = cholesky(M)` to use a mass matrix `M` as a +preconditioner. The proposal becomes `y = x + ε M ∇logp(x) + √(2ε) L ξ` +where `L` is the Cholesky factor of `M`. +""" +struct MALASampler{FP<:AbstractFloat, CM} <: AbstractMCMC.AbstractSampler + epsilon::FP + cholM::CM +end + +function MALASampler(epsilon::Real; cholM=nothing) + epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) + eps_f = float(epsilon) + return MALASampler{typeof(eps_f), typeof(cholM)}(eps_f, cholM) +end + +struct MALAState{V<:AbstractVector, L<:Real} + x::V + logp::L +end + +struct MALATransition{V<:AbstractVector, L<:Real} + x::V + logp::L + accepted::Bool +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::MALASampler{FP}; + initial_params=nothing, + kwargs..., +) where {FP} + x = if initial_params !== nothing + copy(initial_params) + else + randn(rng, FP, model.dim) + end + logp_val = model.logdensity(x) + t = MALATransition(x, logp_val, true) + s = MALAState(x, logp_val) + return t, s +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::MALASampler, + state::MALAState; + kwargs..., +) + x = state.x + ϵ = sampler.epsilon + D = model.dim + + ξ = randn(rng, eltype(x), D) + u = rand(rng) + + x_next, accepted = MALA.mala_step_full( + model.logdensity, model.grad_logdensity, x, ϵ, ξ, u; + cholM=sampler.cholM, + ) + + logp_val = accepted ? 
model.logdensity(x_next) : state.logp + t = MALATransition(x_next, logp_val, accepted) + s = MALAState(x_next, logp_val) + return t, s +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:MALATransition}, + model::DensityModel, + sampler::MALASampler, + state::MALAState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp, :accepted] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 2) + + for i in 1:N + s = samples[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 1.0 : 0.0 + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +""" + MALATapeElement(ξ, u) + +One element of the MALA noise tape: a noise vector `ξ ~ N(0,I)` and a uniform +scalar `u ~ Uniform(0,1)`. Stored with a concrete vector type `V` for type +stability inside `DEER.TapedRecursion`. +""" +struct MALATapeElement{FP<:AbstractFloat, V<:AbstractVector{FP}} + ξ::V + u::FP +end + +""" + DEERSampler(epsilon; T, maxiter, tol_abs, tol_rel, jacobian, damping, probes, cholM, backend) + +DEER-accelerated MALA sampler. + +DEER solves for a trajectory of `T` steps in parallel (O(log T) sweep levels), +then the AbstractMCMC interface returns samples from that trajectory +sequentially. When the trajectory is exhausted a new tape is drawn and DEER +re-solves starting from the last state. + +# Arguments +- `epsilon` — MALA step size. +- `T` — trajectory length per DEER solve (default 64). +- `maxiter` — maximum DEER iterations per solve (default 200). +- `tol_abs`, `tol_rel` — convergence tolerances (default 1e-6, 1e-5). +- `jacobian` — Jacobian mode: `:diag`, `:stoch_diag`, or `:full` (default `:diag`). 
+- `damping` — DEER damping in (0,1] (default 0.5; helps convergence). +- `probes` — Hutchinson probes for `:stoch_diag` mode (default 1). +- `cholM` — optional Cholesky factor of a mass matrix `M` (default `nothing` = identity). +- `backend` — AD backend for Jacobian computation (default `AutoMooncake()`). + For GPU execution pass a GPU-compatible backend, e.g. `AutoEnzyme()`. + +# GPU use +The parallel-prefix scan (`solve_affine_scan_diag`) runs as pure array broadcasts +and is array-type-agnostic. To run DEER on GPU: +1. Implement `logdensity` and `grad_logdensity` using GPU-compatible operations. +2. Pass `backend = AutoEnzyme()` (or another GPU-compatible `ADTypes` backend). +3. Pass a GPU vector as `initial_params` to `sample`. + +# Parallel chains +Both `MALASampler` and `DEERSampler` are compatible with +`AbstractMCMC.sample(model, sampler, MCMCThreads(), N, nchains)`. Each chain +has its own immutable state so there is no shared mutable data. +""" +struct DEERSampler{FP<:AbstractFloat, CM, AD} <: AbstractMCMC.AbstractSampler + epsilon::FP + T::Int + maxiter::Int + tol_abs::FP + tol_rel::FP + jacobian::Symbol + damping::FP + probes::Int + cholM::CM + backend::AD +end + +function DEERSampler( + epsilon::Real; + T::Int=64, + maxiter::Int=200, + tol_abs::Real=1e-6, + tol_rel::Real=1e-5, + jacobian::Symbol=:diag, + damping::Real=0.5, + probes::Int=1, + cholM=nothing, + backend=DEER.DEFAULT_BACKEND, +) + epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) + eps_f = float(epsilon) + FP = typeof(eps_f) + return DEERSampler{FP, typeof(cholM), typeof(backend)}( + eps_f, T, maxiter, + FP(tol_abs), FP(tol_rel), + jacobian, FP(damping), probes, cholM, backend, + ) +end + +""" +State for a `DEERSampler` chain. + +- `x` — current position (= `trajectory[:, t]`, the last returned sample). +- `logp` — log-density at `x`. +- `trajectory` — D×T matrix produced by the most recent DEER solve. +- `tape` — noise tape used for that solve. 
+- `t` — index within `trajectory` of the last returned sample (1-indexed).
+"""
+struct DEERState{V<:AbstractVector, L<:Real, M<:AbstractMatrix, TP<:MALATapeElement}
+    x::V
+    logp::L
+    trajectory::M
+    tape::Vector{TP}  # concrete eltype via TP (was Vector{<:MALATapeElement}) for type stability
+    t::Int
+end
+
+"""
+One DEER sample: parameter vector `x` and its log-density `logp`.
+"""
+struct DEERTransition{V<:AbstractVector, L<:Real}
+    x::V
+    logp::L
+end
+
+function _build_mala_deer_rec(
+    model::DensityModel, ε::Real, tape::Vector{<:MALATapeElement};
+    cholM=nothing, backend=DEER.DEFAULT_BACKEND,
+)
+    logp = model.logdensity
+    gradlogp = model.grad_logdensity
+
+    step_fwd = (x, te) -> MALA.mala_step_taped(logp, gradlogp, x, ε, te.ξ, te.u; cholM=cholM)
+    step_lin = (x, te, a) -> MALA.mala_step_surrogate(logp, gradlogp, x, ε, te.ξ, a; cholM=cholM)
+    consts = (x, te) -> (MALA.mala_accept_indicator(logp, gradlogp, x, ε, te.ξ, te.u; cholM=cholM),)
+
+    return DEER.TapedRecursion(
+        step_fwd, step_lin, tape;
+        consts=consts, const_example=(0.0,), backend=backend,
+    )
+end
+
+function _deer_solve_new_tape(
+    rng::Random.AbstractRNG,
+    model::DensityModel,
+    sampler::DEERSampler,
+    x0::AbstractVector,
+)
+    D = model.dim
+    T = sampler.T
+    FP = typeof(sampler.epsilon)
+    # Generate noise on the same device as x0: generate on CPU then copy via copyto!.
+    # copyto!(CuVector, Vector) performs a host-to-device transfer in CUDA.jl.
+ tape = map(1:T) do _ + ξ = copyto!(similar(x0, D), randn(rng, FP, D)) + MALATapeElement(ξ, FP(rand(rng))) + end + rec = _build_mala_deer_rec(model, sampler.epsilon, tape; cholM=sampler.cholM, backend=sampler.backend) + S = DEER.solve( + rec, x0; + tol_abs = sampler.tol_abs, + tol_rel = sampler.tol_rel, + maxiter = sampler.maxiter, + jacobian = sampler.jacobian, + damping = sampler.damping, + probes = sampler.probes, + rng = rng, + ) + return S, tape +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::DEERSampler{FP}; + initial_params=nothing, + kwargs..., +) where {FP} + x0 = if initial_params !== nothing + copy(initial_params) + else + randn(rng, FP, model.dim) + end + + S, tape = _deer_solve_new_tape(rng, model, sampler, x0) + x1 = S[:, 1] + logp1 = model.logdensity(x1) + trans = DEERTransition(x1, logp1) + state = DEERState(x1, logp1, S, tape, 1) + return trans, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::DEERSampler, + state::DEERState; + kwargs..., +) + T = sampler.T + t_next = state.t + 1 + + if t_next <= T + # Consume the next cached sample from the trajectory. + x_new = state.trajectory[:, t_next] + logp_new = model.logdensity(x_new) + trans = DEERTransition(x_new, logp_new) + new_state = DEERState(x_new, logp_new, state.trajectory, state.tape, t_next) + return trans, new_state + else + # Trajectory exhausted — re-solve with a fresh tape. 
+ x0 = state.trajectory[:, T] + S_new, tape = _deer_solve_new_tape(rng, model, sampler, x0) + x_new = S_new[:, 1] + logp_new = model.logdensity(x_new) + trans = DEERTransition(x_new, logp_new) + new_state = DEERState(x_new, logp_new, S_new, tape, 1) + return trans, new_state + end +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:DEERTransition}, + model::DensityModel, + sampler::DEERSampler, + state::DEERState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 1) + + for i in 1:N + vals[i, :] .= samples[i].x + internals[i, 1] = samples[i].logp + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +""" + AdaptiveMALASampler(epsilon_init; n_warmup, target_accept, gamma, t0, kappa, cholM) + +MALA sampler with automatic step-size adaptation via dual averaging +(Nesterov 2009, as used in NUTS — Hoffman & Gelman 2014). + +During the first `n_warmup` steps the step size `ε` is adapted online to drive +the Metropolis acceptance rate toward `target_accept` (default 0.574, which is +the asymptotically optimal rate for MALA in high dimensions). After warmup the +smoothed estimate `ε̄` is frozen and used for all remaining steps. + +# Algorithm +At warmup step `m`, given current log-acceptance ratio `logα`: + + α = min(1, exp(logα)) + H̄_m = (1 − 1/(m+t₀)) H̄_{m−1} + (1/(m+t₀)) (δ − α) + log ε_m = μ − √m/γ · H̄_m (instantaneous) + log ε̄_m = m^(−κ) log ε_m + (1−m^(−κ)) log ε̄_{m−1} (smoothed) + +where μ = log(10 ε₀) is a fixed target. After warmup `ε̄` is used. + +# Keyword arguments +- `n_warmup` — adaptation steps (default 1000). +- `target_accept` — δ, desired acceptance rate (default 0.574). 
+- `gamma` — γ, regularisation strength (default 0.05). +- `t0` — stability offset (default 10.0). +- `kappa` — shrinkage exponent κ ∈ (0.5, 1] (default 0.75). +- `cholM` — optional Cholesky factor of a mass matrix `M` (default `nothing` = identity). + +# MCMCChains output +The `Chains` object includes internals `[:logp, :accepted, :step_size, :is_warmup]`. +After warmup, `step_size` is constant (the frozen `ε̄`). + +# Parallel chains +Works with `MCMCThreads()`. R-hat and ESS are computed automatically by +`MCMCChains` from multi-chain output. + +# Turing.jl / LogDensityProblems +Load `LogDensityProblems` (and optionally `LogDensityProblemsAD`) then use the +`DensityModel(ld)` constructor to wrap any `LogDensityProblems`-compatible model +(including Turing/DynamicPPL models) directly. +""" +struct AdaptiveMALASampler{FP<:AbstractFloat, CM} <: AbstractMCMC.AbstractSampler + epsilon_init::FP + n_warmup::Int + target_accept::FP + gamma::FP + t0::FP + kappa::FP + cholM::CM +end + +function AdaptiveMALASampler( + epsilon_init::Real; + n_warmup::Int=1000, + target_accept::Real=0.574, + gamma::Real=0.05, + t0::Real=10.0, + kappa::Real=0.75, + cholM=nothing, +) + epsilon_init > 0 || throw(ArgumentError("epsilon_init must be > 0, got $epsilon_init")) + 0 < target_accept < 1 || throw(ArgumentError("target_accept must be in (0,1), got $target_accept")) + gamma > 0 || throw(ArgumentError("gamma must be > 0, got $gamma")) + t0 > 0 || throw(ArgumentError("t0 must be > 0, got $t0")) + 0.5 < kappa <= 1.0 || throw(ArgumentError("kappa must be in (0.5, 1], got $kappa")) + + eps_f = float(epsilon_init) + FP = typeof(eps_f) + return AdaptiveMALASampler{FP, typeof(cholM)}( + eps_f, n_warmup, + FP(target_accept), FP(gamma), FP(t0), FP(kappa), + cholM, + ) +end + +struct AdaptiveMALAState{V<:AbstractVector, FP<:AbstractFloat} + x::V + logp::FP + epsilon::FP # instantaneous step size ε_m + epsilon_bar::FP # smoothed step size ε̄_m (frozen after warmup) + H_bar::FP # dual-average 
statistic H̄_m + step::Int # warmup step counter (0 = initialisation) +end + +struct AdaptiveMALATransition{V<:AbstractVector, FP<:AbstractFloat} + x::V + logp::FP + accepted::Bool + step_size::FP # ε used for this step + is_warmup::Bool +end + +function _dual_average_update( + epsilon_init::FP, + epsilon_bar::FP, + H_bar::FP, + m::Int, + logα::FP, + sampler::AdaptiveMALASampler{FP}, +) where {FP<:AbstractFloat} + α = min(one(FP), exp(logα)) + δ = sampler.target_accept + γ = sampler.gamma + t0 = sampler.t0 + κ = sampler.kappa + μ = log(10 * epsilon_init) # fixed shrinkage target + + inv_mt0 = one(FP) / (FP(m) + t0) + H_bar_new = (one(FP) - inv_mt0) * H_bar + inv_mt0 * (δ - α) + log_ε = μ - sqrt(FP(m)) / γ * H_bar_new + mk = FP(m)^(-κ) + log_ε_bar_new = mk * log_ε + (one(FP) - mk) * log(epsilon_bar) + + return exp(log_ε), exp(log_ε_bar_new), H_bar_new +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::AdaptiveMALASampler{FP}; + initial_params=nothing, + kwargs..., +) where {FP} + x = if initial_params !== nothing + copy(initial_params) + else + randn(rng, FP, model.dim) + end + logp_val = FP(model.logdensity(x)) + trans = AdaptiveMALATransition(x, logp_val, true, sampler.epsilon_init, true) + state = AdaptiveMALAState(x, logp_val, sampler.epsilon_init, sampler.epsilon_init, zero(FP), 0) + return trans, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::AdaptiveMALASampler{FP}, + state::AdaptiveMALAState; + kwargs..., +) where {FP} + D = model.dim + in_warmup = state.step < sampler.n_warmup + ε = in_warmup ? state.epsilon : state.epsilon_bar + + ξ = randn(rng, eltype(state.x), D) + u = rand(rng) + + x_next, accepted, logα = MALA.mala_step_with_logα( + model.logdensity, model.grad_logdensity, state.x, ε, ξ, u; + cholM=sampler.cholM, + ) + + logp_next = accepted ? 
FP(model.logdensity(x_next)) : state.logp + + # Dual-average adaptation (only during warmup) + m_new = state.step + 1 + ε_new, ε_bar_new, H_bar_new = if in_warmup + _dual_average_update( + sampler.epsilon_init, state.epsilon_bar, state.H_bar, + m_new, FP(logα), sampler, + ) + else + state.epsilon, state.epsilon_bar, state.H_bar + end + + trans = AdaptiveMALATransition(x_next, logp_next, accepted, ε, in_warmup) + new_state = AdaptiveMALAState(x_next, logp_next, ε_new, ε_bar_new, H_bar_new, m_new) + return trans, new_state +end + +function AbstractMCMC.bundle_samples( + samples::Vector{<:AdaptiveMALATransition}, + model::DensityModel, + sampler::AdaptiveMALASampler, + state::AdaptiveMALAState, + ::Type{MCMCChains.Chains}; + param_names=nothing, + kwargs..., +) + N = length(samples) + D = model.dim + + names = if param_names !== nothing + param_names + else + [Symbol("x[$i]") for i in 1:D] + end + + internal_names = [:logp, :accepted, :step_size, :is_warmup] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 4) + + for i in 1:N + s = samples[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 1.0 : 0.0 + internals[i, 3] = s.step_size + internals[i, 4] = s.is_warmup ? 1.0 : 0.0 + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +""" + DensityModel(ld) + +Construct a `DensityModel` from any object `ld` that implements the +[LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface, +i.e. provides `LogDensityProblems.logdensity`, `LogDensityProblems.logdensity_and_gradient`, +and `LogDensityProblems.dimension`. 
+ +Requires the `LogDensityProblems` package to be loaded: + +```julia +using LogDensityProblems # gradient-free: only logdensity used +using LogDensityProblemsAD # or this, for AD-based gradients +``` + +# Turing.jl example +```julia +using Turing, LogDensityProblems, LogDensityProblemsAD, Mooncake + +@model function mymodel(data) + μ ~ Normal(0, 1) + data ~ Normal(μ, 1) +end + +ld = DynamicPPL.LogDensityFunction(mymodel(obs)) +ldg = LogDensityProblemsAD.ADgradient(Mooncake.Extras.MooncakeAD(), ld) +model = DensityModel(ldg) + +chain = sample(model, AdaptiveMALASampler(0.1; n_warmup=500), 2000; + chain_type=MCMCChains.Chains, progress=true) +``` + +This method is defined in the `LogDensityProblemsExt` extension and is only +available when `LogDensityProblems` has been loaded. +""" +function DensityModel end # extended by LogDensityProblemsExt diff --git a/test/Project.toml b/test/Project.toml index 4e216a9..e1dfe28 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,16 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[extras] +CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/test/test-AbstractMCMC-Interface.jl b/test/test-AbstractMCMC-Interface.jl new file mode 100644 index 0000000..3da103d --- /dev/null +++ b/test/test-AbstractMCMC-Interface.jl @@ -0,0 +1,198 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC + +logp_iface(x) = 
-0.5 * dot(x, x) +gradlogp_iface(x) = -x + +@testset "AbstractMCMC interface" begin + @testset "DensityModel construction" begin + m = DensityModel(logp_iface, gradlogp_iface, 3) + @test m isa ParallelMCMC.AbstractMCMC.AbstractModel + @test m.dim == 3 + @test m.logdensity([0.0, 0.0, 0.0]) == 0.0 + @test m.grad_logdensity([1.0, 2.0, 3.0]) == [-1.0, -2.0, -3.0] + end + + @testset "MALASampler construction" begin + s = MALASampler(0.1) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon == 0.1 + end + + @testset "initial step draws from rng" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 5) + sampler = MALASampler(0.1) + + transition, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test transition isa MALATransition + @test state isa MALAState + @test length(transition.x) == 5 + @test length(state.x) == 5 + @test transition.logp == logp_iface(transition.x) + @test state.logp == transition.logp + @test transition.accepted == true + end + + @testset "initial step respects initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + x0 = [1.0, 2.0, 3.0] + + transition, state = ParallelMCMC.AbstractMCMC.step( + rng, model, sampler; initial_params=x0 + ) + + @test transition.x == x0 + @test state.x == x0 + @test transition.logp == logp_iface(x0) + end + + @testset "initial step does not mutate initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + x0 = [1.0, 2.0, 3.0] + x0_copy = copy(x0) + + ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + @test x0 == x0_copy + end + + @testset "subsequent step produces valid transition" begin + rng = MersenneTwister(42) + model = DensityModel(logp_iface, gradlogp_iface, 4) + sampler = MALASampler(0.1) + + t1, s1 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + t2, s2 = 
ParallelMCMC.AbstractMCMC.step(rng, model, sampler, s1) + + @test t2 isa MALATransition + @test s2 isa MALAState + @test length(t2.x) == 4 + @test t2.accepted isa Bool + @test isfinite(t2.logp) + @test s2.x == t2.x + @test s2.logp == t2.logp + end + + @testset "step determinism with fixed rng" begin + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.2) + + rng1 = MersenneTwister(999) + t1a, s1a = ParallelMCMC.AbstractMCMC.step(rng1, model, sampler) + t1b, s1b = ParallelMCMC.AbstractMCMC.step(rng1, model, sampler, s1a) + + rng2 = MersenneTwister(999) + t2a, s2a = ParallelMCMC.AbstractMCMC.step(rng2, model, sampler) + t2b, s2b = ParallelMCMC.AbstractMCMC.step(rng2, model, sampler, s2a) + + @test t1a.x == t2a.x + @test t1b.x == t2b.x + @test t1b.accepted == t2b.accepted + end + + @testset "rejection preserves logp from state" begin + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.5) + + # Run enough steps to hit a rejection + rng = MersenneTwister(12345) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + found_rejection = false + for _ in 1:500 + t, state_new = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + if !t.accepted + # On rejection, x doesn't move and logp should match prior state + @test t.x == state.x + @test t.logp == state.logp + found_rejection = true + break + end + state = state_new + end + @test found_rejection + end + + @testset "sample() runs end-to-end" begin + model = DensityModel(logp_iface, gradlogp_iface, 3) + sampler = MALASampler(0.1) + + chain = sample(model, sampler, 100; progress=false) + @test length(chain) == 100 + end + + @testset "sample() with chain_type=Chains" begin + model = DensityModel(logp_iface, gradlogp_iface, 2) + sampler = MALASampler(0.15) + + chain = sample( + model, sampler, 200; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 200 + + # Parameter columns present + 
param_names = names(chain, :parameters) + @test length(param_names) == 2 + + # Internal columns present + internal_names = names(chain, :internals) + @test :logp in internal_names + @test :accepted in internal_names + + # logp values should be finite + @test all(isfinite, chain[:logp]) + + # accepted values should be 0 or 1 + acc = chain[:accepted] + @test all(a -> a == 0.0 || a == 1.0, acc) + end + + @testset "sample() with custom param_names" begin + model = DensityModel(logp_iface, gradlogp_iface, 2) + sampler = MALASampler(0.15) + + chain = sample( + model, sampler, 50; + chain_type=MCMCChains.Chains, progress=false, + param_names=[:mu, :sigma], + ) + + @test chain isa MCMCChains.Chains + param_names = names(chain, :parameters) + @test :mu in param_names + @test :sigma in param_names + end + + @testset "stationary distribution via sample()" begin + D = 3 + model = DensityModel(logp_iface, gradlogp_iface, D) + sampler = MALASampler(0.3) + + chain = sample( + MersenneTwister(2025), model, sampler, 20_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + burn = 3_000 + post = Array(chain[burn:end, :, :]) # (N-burn) × D + + mu = vec(mean(post; dims=1)) + @test maximum(abs.(mu)) < 0.1 + + vars = vec(var(post; dims=1)) + @test maximum(abs.(vars .- 1.0)) < 0.15 + end +end diff --git a/test/test-Adaptive-MALA.jl b/test/test-Adaptive-MALA.jl new file mode 100644 index 0000000..15ce7f2 --- /dev/null +++ b/test/test-Adaptive-MALA.jl @@ -0,0 +1,238 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC +const MALA = ParallelMCMC.MALA + +logp_adapt(x) = -0.5 * dot(x, x) +gradlogp_adapt(x) = -x + + +@testset "mala_step_with_logα returns same x_next and accepted as mala_step_full" begin + rng = MersenneTwister(1) + D = 4 + for _ in 1:20 + x = randn(rng, D) + ξ = randn(rng, D) + u = rand(rng) * 0.999 + 1e-15 + + x1, a1 = MALA.mala_step_full(logp_adapt, gradlogp_adapt, x, 0.1, ξ, u) + x2, a2, logα = 
MALA.mala_step_with_logα(logp_adapt, gradlogp_adapt, x, 0.1, ξ, u) + + @test x1 == x2 + @test a1 == a2 + @test isfinite(logα) + # When accepted logα ≥ log(u), when rejected logα < log(u) + @test a2 == (log(u) < logα) + end +end + +@testset "AdaptiveMALASampler construction" begin + s = AdaptiveMALASampler(0.1) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon_init == 0.1 + @test s.n_warmup == 1000 + @test s.target_accept ≈ 0.574 + @test s.gamma ≈ 0.05 + @test s.t0 ≈ 10.0 + @test s.kappa ≈ 0.75 + @test s.cholM === nothing + + s2 = AdaptiveMALASampler(0.05; n_warmup=500, target_accept=0.65) + @test s2.n_warmup == 500 + @test s2.target_accept ≈ 0.65 +end + +@testset "AdaptiveMALASampler initial step" begin + rng = MersenneTwister(42) + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + sampler = AdaptiveMALASampler(0.1) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test trans isa AdaptiveMALATransition + @test state isa AdaptiveMALAState + @test length(trans.x) == 3 + @test isfinite(trans.logp) + @test trans.step_size == sampler.epsilon_init + @test trans.is_warmup == true + @test state.step == 0 + @test state.epsilon == sampler.epsilon_init + @test state.epsilon_bar == sampler.epsilon_init +end + +@testset "AdaptiveMALASampler initial step respects initial_params" begin + rng = MersenneTwister(7) + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.1) + x0 = [1.0, -1.0] + x0_copy = copy(x0) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + + @test x0 == x0_copy # not mutated + @test trans.x == x0 +end + +@testset "step counter increments during warmup" begin + rng = MersenneTwister(5) + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.05; n_warmup=10) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + for expected_step in 1:5 + _, state = ParallelMCMC.AbstractMCMC.step(rng, 
model, sampler, state) + @test state.step == expected_step + @test state.step <= sampler.n_warmup # still in warmup + end +end + +@testset "step size changes during warmup but freezes after" begin + rng = MersenneTwister(11) + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + n_w = 20 + sampler = AdaptiveMALASampler(0.1; n_warmup=n_w) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + epsilons_warmup = Float64[] + for _ in 1:n_w + t, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + push!(epsilons_warmup, t.step_size) + @test t.is_warmup == true + end + + # At least some epsilon change during warmup + @test !all(==(epsilons_warmup[1]), epsilons_warmup) + + # First post-warmup step: is_warmup should be false, step_size should be frozen ε̄ + epsilon_bar_final = state.epsilon_bar + t_post, state_post = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test t_post.is_warmup == false + @test t_post.step_size ≈ epsilon_bar_final + + # Subsequent post-warmup steps keep the same step_size + _, state_post2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state_post) + @test state_post2.epsilon_bar ≈ epsilon_bar_final +end + +@testset "adapted step size targets acceptance rate" begin + D = 5 + model = DensityModel(logp_adapt, gradlogp_adapt, D) + n_w = 2_000 + sampler = AdaptiveMALASampler(0.5; n_warmup=n_w, target_accept=0.574) + + # Use a single persistent RNG throughout warmup and post-warmup. + rng = MersenneTwister(42) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + for _ in 1:n_w + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + end + + # After warmup, ε̄ should yield acceptance near target. + # Run 500 post-warmup steps and measure actual rate. 
+ n_accept = 0 + for _ in 1:500 + t, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + n_accept += t.accepted + end + accept_rate = n_accept / 500 + + @test 0.35 < accept_rate < 0.80 # wide tolerance; adaptation is approximate +end + +@testset "sample() end-to-end with AdaptiveMALASampler" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.2; n_warmup=50) + + samples = sample(MersenneTwister(1), model, sampler, 100; progress=false) + @test length(samples) == 100 +end + +@testset "sample() with chain_type=Chains" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.2; n_warmup=50) + + chain = sample( + MersenneTwister(1), model, sampler, 150; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 150 + + internals = names(chain, :internals) + @test :logp in internals + @test :accepted in internals + @test :step_size in internals + @test :is_warmup in internals + + @test all(isfinite, chain[:logp]) + @test all(x -> x == 0.0 || x == 1.0, chain[:accepted]) + @test all(s -> s > 0, chain[:step_size]) + + # Warmup samples appear at the start + @test chain[:is_warmup][1] == 1.0 + # Post-warmup samples have is_warmup == 0 + @test chain[:is_warmup][end] == 0.0 +end + +@testset "step_size is constant after warmup" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 3) + n_w = 30 + sampler = AdaptiveMALASampler(0.1; n_warmup=n_w) + + chain = sample( + MersenneTwister(3), model, sampler, n_w + 50; + chain_type=MCMCChains.Chains, progress=false, + ) + + # Filter by is_warmup flag to avoid off-by-one from the init transition. 
+ is_wup = vec(chain[:is_warmup]) .== 1.0 + step_sizes = vec(chain[:step_size]) + post_warmup = step_sizes[.!is_wup] + + @test length(post_warmup) > 0 + # All post-warmup step sizes must be identical (frozen ε̄) + @test all(≈(post_warmup[1]), post_warmup) +end + +@testset "AdaptiveMALASampler stationary distribution" begin + D = 3 + model = DensityModel(logp_adapt, gradlogp_adapt, D) + n_w = 1_000 + sampler = AdaptiveMALASampler(0.5; n_warmup=n_w) + + chain = sample( + MersenneTwister(2025), model, sampler, n_w + 5_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + # Discard warmup + post = Array(chain[(n_w + 1):end, :, :]) # 5000 × D + + mu = vec(mean(post; dims=1)) + vars = vec(var(post; dims=1)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1.0)) < 0.25 +end + +@testset "AdaptiveMALASampler parallel chains via MCMCThreads" begin + model = DensityModel(logp_adapt, gradlogp_adapt, 2) + sampler = AdaptiveMALASampler(0.1; n_warmup=20) + + chains = sample( + MersenneTwister(42), model, sampler, + ParallelMCMC.AbstractMCMC.MCMCThreads(), 60, 2; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chains isa MCMCChains.Chains + @test size(chains, 1) == 60 + @test size(chains, 3) == 2 +end diff --git a/test/test-DEER-Interface.jl b/test/test-DEER-Interface.jl new file mode 100644 index 0000000..eaa0492 --- /dev/null +++ b/test/test-DEER-Interface.jl @@ -0,0 +1,170 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC + +logp_deer(x) = -0.5 * dot(x, x) +gradlogp_deer(x) = -x + +@testset "DEERSampler construction" begin + s = DEERSampler(0.05) + @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler + @test s.epsilon == 0.05 + @test s.T == 64 + @test s.maxiter == 200 + @test s.jacobian === :diag + + # keyword overrides + s2 = DEERSampler(0.1; T=32, jacobian=:stoch_diag, damping=0.8) + @test s2.T == 32 + @test s2.jacobian === :stoch_diag + @test s2.damping == 0.8 +end + +@testset 
"DEERSampler initial step" begin + rng = MersenneTwister(42) + model = DensityModel(logp_deer, gradlogp_deer, 3) + sampler = DEERSampler(0.05; T=16) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + @test trans isa DEERTransition + @test state isa DEERState + @test length(trans.x) == 3 + @test isfinite(trans.logp) + @test trans.logp ≈ logp_deer(trans.x) + @test state.t == 1 + @test size(state.trajectory) == (3, 16) + @test length(state.tape) == 16 +end + +@testset "DEERSampler initial step respects initial_params" begin + rng = MersenneTwister(42) + model = DensityModel(logp_deer, gradlogp_deer, 3) + sampler = DEERSampler(0.05; T=8) + x0 = [1.0, 2.0, 3.0] + x0_copy = copy(x0) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0) + + # initial_params not mutated + @test x0 == x0_copy + # first sample should differ from x0 (DEER solves the trajectory) + @test length(trans.x) == 3 + @test isfinite(trans.logp) +end + +@testset "DEERSampler sequential steps advance trajectory index" begin + rng = MersenneTwister(7) + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=8) + + trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + @test state.t == 1 + + trans2, state2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state2.t == 2 + @test trans2.x ≈ state.trajectory[:, 2] + + trans3, state3 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state2) + @test state3.t == 3 +end + +@testset "DEERSampler re-solves at trajectory boundary" begin + rng = MersenneTwister(99) + model = DensityModel(logp_deer, gradlogp_deer, 2) + T = 4 + sampler = DEERSampler(0.05; T=T) + + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) + + # advance to t == T + for _ in 1:(T - 1) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + end + @test state.t == T + + # next step should trigger re-solve: t resets to 1 with a new trajectory + _, 
state_new = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) + @test state_new.t == 1 + # new trajectory should differ from old (different tape) + @test state_new.trajectory !== state.trajectory +end + +@testset "DEERSampler sample() end-to-end" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + samples = sample(MersenneTwister(1), model, sampler, 50; progress=false) + @test length(samples) == 50 +end + +@testset "DEERSampler sample() with chain_type=Chains" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + chain = sample( + MersenneTwister(1), model, sampler, 100; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 100 + @test :logp in names(chain, :internals) + @test all(isfinite, chain[:logp]) + + param_names = names(chain, :parameters) + @test length(param_names) == 2 +end + +@testset "DEERSampler sample() with custom param_names" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=16) + + chain = sample( + MersenneTwister(2), model, sampler, 40; + chain_type=MCMCChains.Chains, progress=false, + param_names=[:mu, :sigma], + ) + + @test :mu in names(chain, :parameters) + @test :sigma in names(chain, :parameters) +end + +@testset "DEERSampler stationary distribution" begin + D = 3 + model = DensityModel(logp_deer, gradlogp_deer, D) + sampler = DEERSampler(0.1; T=32, damping=0.5) + + chain = sample( + MersenneTwister(2025), model, sampler, 5_000; + chain_type=MCMCChains.Chains, progress=false, + ) + + burn = 500 + post = Array(chain[burn:end, :, :]) # (N-burn) × D + + mu = vec(mean(post; dims=1)) + vars = vec(var(post; dims=1)) + + @test maximum(abs.(mu)) < 0.15 + @test maximum(abs.(vars .- 1.0)) < 0.25 +end + +@testset "DEERSampler parallel chains via MCMCThreads" begin + model = DensityModel(logp_deer, gradlogp_deer, 2) + sampler = DEERSampler(0.05; T=8) + + 
chains = sample( + MersenneTwister(42), model, sampler, + ParallelMCMC.AbstractMCMC.MCMCThreads(), 40, 2; + chain_type=MCMCChains.Chains, progress=false, + ) + + @test chains isa MCMCChains.Chains + @test size(chains, 1) == 40 # samples per chain + @test size(chains, 3) == 2 # number of chains +end diff --git a/test/test-Deer-vs-MALA.jl b/test/test-Deer-vs-MALA.jl index e52d0ba..b0567d1 100644 --- a/test/test-Deer-vs-MALA.jl +++ b/test/test-Deer-vs-MALA.jl @@ -173,7 +173,7 @@ end s_prev = copy(s0) for t in 1:T s_prev = view(A, :, t) .* s_prev .+ view(B, :, t) - @inbounds S_ref[:, t] .= s_prev + S_ref[:, t] .= s_prev end # Scan diff --git a/test/test-GPU-DEER.jl b/test/test-GPU-DEER.jl new file mode 100644 index 0000000..0924a80 --- /dev/null +++ b/test/test-GPU-DEER.jl @@ -0,0 +1,109 @@ +using Test +using Random +using LinearAlgebra + +using ParallelMCMC +import ParallelMCMC: DEER +import ADTypes + +cuda_ok = try + using CUDA + CUDA.functional() +catch + false +end + +if !cuda_ok + @warn "CUDA not available or not functional — skipping GPU DEER tests" +else + +_logp_gpu(x) = -0.5f0 * sum(abs2, x) +_gradlogp_gpu(x) = -x + +@testset "solve_affine_scan_diag CPU vs GPU" begin + rng = MersenneTwister(1) + D, T = 4, 16 + + A_cpu = randn(rng, Float32, D, T) + B_cpu = randn(rng, Float32, D, T) + s0_cpu = randn(rng, Float32, D) + + S_cpu = DEER.solve_affine_scan_diag(A_cpu, B_cpu, s0_cpu) + + A_gpu = CUDA.cu(A_cpu) + B_gpu = CUDA.cu(B_cpu) + s0_gpu = CUDA.cu(s0_cpu) + + S_gpu = DEER.solve_affine_scan_diag(A_gpu, B_gpu, s0_gpu) + + @test S_gpu isa CUDA.CuMatrix + @test size(S_gpu) == (D, T) + @test Array(S_gpu) ≈ S_cpu atol=1e-5 +end + +enzyme_ok = try + using Enzyme + true +catch + false +end + +if !enzyme_ok + @warn "Enzyme not available — skipping GPU AD DEER tests" +else + + +# @testset "DEER.solve on GPU with AutoEnzyme" begin +# rng = MersenneTwister(42) +# D, T = 3, 8 + +# ε = 0.05f0 +# tape = map(1:T) do _ +# ξ_gpu = CUDA.cu(randn(rng, Float32, D)) +# u = 
Float32(rand(rng)) +# MALATapeElement(ξ_gpu, u) +# end + +# rec = DEER.TapedRecursion( +# (x, te) -> MALA.mala_step_taped(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, te.u), +# (x, te, a) -> MALA.mala_step_surrogate(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, a), +# tape; +# consts = (x, te) -> (MALA.mala_accept_indicator(_logp_gpu, _gradlogp_gpu, x, ε, te.ξ, te.u),), +# const_example = (0.0f0,), +# backend = ADTypes.AutoEnzyme(), +# ) + +# s0_gpu = CUDA.cu(randn(rng, Float32, D)) +# S, info = DEER.solve(rec, s0_gpu; jacobian=:stoch_diag, maxiter=20, return_info=true) + +# @test S isa CUDA.CuMatrix +# @test size(S) == (D, T) +# @test all(isfinite, Array(S)) +# end + +# @testset "DEERSampler interface on GPU" begin +# rng = MersenneTwister(7) +# D = 3 + +# model = DensityModel(_logp_gpu, _gradlogp_gpu, D) +# sampler = DEERSampler(0.05f0; T=8, maxiter=20, jacobian=:stoch_diag, backend=ADTypes.AutoEnzyme()) + +# x0_gpu = CUDA.cu(randn(rng, Float32, D)) +# trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=x0_gpu) + +# @test trans isa DEERTransition +# @test state isa DEERState +# @test trans.x isa CUDA.CuVector +# @test length(trans.x) == D +# @test isfinite(Float32(trans.logp)) +# @test size(state.trajectory) == (D, 8) +# @test state.trajectory isa CUDA.CuMatrix + +# # Second step consumes from the cached trajectory +# trans2, state2 = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) +# @test state2.t == 2 +# @test trans2.x isa CUDA.CuVector +# end + +end # enzyme_ok +end # cuda_ok diff --git a/test/test-GPU-MALA.jl b/test/test-GPU-MALA.jl new file mode 100644 index 0000000..b21fc4e --- /dev/null +++ b/test/test-GPU-MALA.jl @@ -0,0 +1,241 @@ +using Test +using Random +using LinearAlgebra +using Statistics + +using ParallelMCMC +const MALA = ParallelMCMC.MALA + +import CUDA + +# Check if a real GPU is accessible by attempting a small allocation. +# CUDA.functional() only checks that the library loads, not that a device exists. 
+const CUDA_AVAILABLE = try + CUDA.CuArray([1f0]) + true +catch + false +end + +if !CUDA_AVAILABLE + @info "No CUDA GPU detected — skipping GPU tests." +else + +# Standard normal: logp = -0.5 ||x||², grad = -x +logp_batch(X) = vec(-0.5f0 .* sum(abs2, X; dims=1)) +gradlogp_batch(X) = -X + +# Scaled normal: dimension i has variance σᵢ² = i (so std = sqrt(i)) +# logp = -0.5 sum_i x_i²/i, grad_i = -x_i/i +function logp_scaled(X) + D = size(X, 1) + scales = CUDA.CuArray(Float32.(1:D)) # D-vector on GPU + return vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1)) +end +function gradlogp_scaled(X) + D = size(X, 1) + scales = CUDA.CuArray(Float32.(1:D)) + return -X ./ scales +end + +@testset "GPU inputs stay on device" begin + D, N = 4, 32 + + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + @test X_next isa CUDA.CuArray + @test accepted isa CUDA.CuArray + @test size(X_next) == (D, N) + @test length(accepted) == N +end + +@testset "GPU eltype preserved (Float32 in → Float32 out)" begin + D, N = 5, 20 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + @test eltype(X_next) == Float32 +end + +@testset "GPU and CPU produce identical results (same seed)" begin + D, N = 3, 16 + + rng = MersenneTwister(42) + X_cpu = randn(rng, Float32, D, N) + Ξ_cpu = randn(rng, Float32, D, N) + u_cpu = rand(rng, Float32, N) + + X_gpu = CUDA.CuArray(X_cpu) + Ξ_gpu = CUDA.CuArray(Ξ_cpu) + u_gpu = CUDA.CuArray(u_cpu) + + X_next_cpu, acc_cpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_cpu, 0.1f0, Ξ_cpu, u_cpu) + X_next_gpu, acc_gpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_gpu, 0.1f0, Ξ_gpu, u_gpu) + + @test Array(X_next_gpu) ≈ X_next_cpu atol=1f-5 + @test Array(acc_gpu) == acc_cpu +end + +@testset "GPU single chain (N=1) 
matches scalar mala_step_full" begin + D = 4 + rng = MersenneTwister(7) + x = randn(rng, Float32, D) + ξ = randn(rng, Float32, D) + u = rand(rng, Float32) + + # Scalar step on CPU + logp_scalar = x -> -0.5f0 * dot(x, x) + grad_scalar = x -> -x + x_scalar, acc_scalar = MALA.mala_step_full(logp_scalar, grad_scalar, x, 0.1f0, ξ, u) + + # Batched N=1 on GPU + X_gpu = CUDA.CuArray(reshape(copy(x), D, 1)) + Ξ_gpu = CUDA.CuArray(reshape(copy(ξ), D, 1)) + u_gpu = CUDA.CuArray([u]) + + X_next_gpu, acc_gpu = MALA.mala_step_batched(logp_batch, gradlogp_batch, X_gpu, 0.1f0, Ξ_gpu, u_gpu) + + @test Array(vec(X_next_gpu)) ≈ x_scalar atol=1f-5 + @test Bool(Array(acc_gpu)[1]) == acc_scalar +end + +@testset "GPU rejected chains are unchanged" begin + D, N = 4, 64 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + # u very close to 1 ⟹ log(u) ≈ 0, forces rejection whenever logα ≤ 0. + # Some chains may still be accepted (logα > 0 when proposal lands at higher density), + # but for every rejected chain X_next must exactly equal X. + u = CUDA.fill(1f0 - 1f-6, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + + acc_cpu = Array(accepted) + X_cpu = Array(X) + Xn_cpu = Array(X_next) + rejected = .!acc_cpu + + # Rejected chains must be exactly unchanged. + @test Xn_cpu[:, rejected] == X_cpu[:, rejected] + # With u ≈ 1 most should be rejected (not a strict all-zero check). + @test sum(acc_cpu) < N +end + +@testset "GPU force all acceptances (u ≈ 0, tiny ε)" begin + D, N = 4, 64 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + # u very close to 0 ⟹ log(u) → -∞ < any logα ⟹ accept + u = CUDA.fill(1f-7, N) + + _, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.001f0, Ξ, u) + + @test sum(Array(accepted)) == N +end + +@testset "GPU acceptance rate in reasonable range" begin + # With ε=0.1 and standard normal, empirical acceptance rate should be + # well above 0 and below 1. 
+ D, N, T = 5, 512, 200 + + n_accepted = 0 + X = CUDA.randn(Float32, D, N) + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, acc = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + n_accepted += sum(Array(acc)) + end + + rate = n_accepted / (T * N) + @test 0.3 < rate < 0.99 +end + +@testset "GPU stationary distribution (standard normal)" begin + D, N, T = 3, 512, 1_000 + + X = CUDA.randn(Float32, D, N) + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, _ = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.1f0, Ξ, u) + end + + X_cpu = Array(X) + @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.15 + @test maximum(abs.(vec(var(X_cpu; dims=2)) .- 1f0)) < 0.25 +end + +@testset "GPU stationary distribution (scaled normal)" begin + # Dimension i has variance i; after burn-in each row should have var ≈ i + D, N, T = 5, 512, 2_000 + + X = CUDA.randn(Float32, D, N) + for _ in 1:T + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + X, _ = MALA.mala_step_batched(logp_scaled, gradlogp_scaled, X, 0.05f0, Ξ, u) + end + + X_cpu = Array(X) + vars = vec(var(X_cpu; dims=2)) # should ≈ [1, 2, 3, 4, 5] + target = Float32.(1:D) + + @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.3 + @test maximum(abs.(vars .- target) ./ target) < 0.3 # within 30% relative +end + +@testset "GPU logq_mala_batched matches CPU" begin + D, N = 4, 16 + rng = MersenneTwister(3) + Y_cpu = randn(rng, Float32, D, N) + X_cpu = randn(rng, Float32, D, N) + G_cpu = randn(rng, Float32, D, N) + + lq_cpu = MALA.logq_mala_batched(Y_cpu, X_cpu, G_cpu, 0.1f0) + + Y_gpu = CUDA.CuArray(Y_cpu) + X_gpu = CUDA.CuArray(X_cpu) + G_gpu = CUDA.CuArray(G_cpu) + + lq_gpu = MALA.logq_mala_batched(Y_gpu, X_gpu, G_gpu, 0.1f0) + + @test Array(lq_gpu) ≈ lq_cpu atol=1f-4 +end + +@testset "GPU large-scale (D=128, N=1024)" begin + # Smoke test: just verify it runs and returns correct shapes without OOM. 
+ D, N = 128, 1024 + X = CUDA.randn(Float32, D, N) + Ξ = CUDA.randn(Float32, D, N) + u = CUDA.rand(Float32, N) + + X_next, accepted = MALA.mala_step_batched(logp_batch, gradlogp_batch, X, 0.05f0, Ξ, u) + + @test size(X_next) == (D, N) + @test length(accepted) == N + @test eltype(X_next) == Float32 +end + +@testset "GPU DimensionMismatch errors" begin + X = CUDA.randn(Float32, 3, 5) + Ξ_bad = CUDA.randn(Float32, 3, 4) + u_ok = CUDA.rand(Float32, 5) + u_bad = CUDA.rand(Float32, 4) + + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1f0, Ξ_bad, u_ok, + ) + @test_throws DimensionMismatch MALA.mala_step_batched( + logp_batch, gradlogp_batch, X, 0.1f0, CUDA.randn(Float32, 3, 5), u_bad, + ) +end + +end # if CUDA_AVAILABLE