Conversation
…e.jl. Minor tweaks to tests/error handling
…perability with Turing/MCMCChains/AbstractMCMC
Pull request overview
This PR expands ParallelMCMC’s MALA implementation to support batched/GPU execution and adds a fuller AbstractMCMC-compatible public interface (including Adaptive MALA and LogDensityProblems integration), along with a broad new test suite.
Changes:
- Add mass-matrix (Cholesky) support, plus batched `mala_step_batched`/`logq_mala_batched` intended to work on GPU inputs.
- Introduce AbstractMCMC interface types (`DensityModel`, samplers/states/transitions) and implement `AbstractMCMC.step`/`bundle_samples` for MALA, Adaptive MALA, and DEER.
- Add LogDensityProblems extension + extensive tests (CPU + conditional CUDA GPU tests) and update test dependencies.
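For context, a single unbatched MALA step — the update the batched kernels generalize — can be sketched in plain Julia. The names `mala_step` and `logq` below are illustrative, not the package's API:

```julia
using Random

# Illustrative MALA step with identity mass matrix (not ParallelMCMC's API).
# Proposal: y = x + ϵ·∇logp(x) + √(2ϵ)·ξ with ξ ~ N(0, I), accepted with the
# Metropolis–Hastings correction using forward/reverse proposal log-densities.
logq(y, x, gx, ϵ) = -sum(abs2, y .- x .- ϵ .* gx) / (4ϵ)

function mala_step(rng, logp, gradlogp, x, ϵ)
    gx = gradlogp(x)
    y = x .+ ϵ .* gx .+ sqrt(2ϵ) .* randn(rng, length(x))
    gy = gradlogp(y)
    loga = logp(y) - logp(x) + logq(x, y, gy, ϵ) - logq(y, x, gx, ϵ)
    return log(rand(rng)) < loga ? (y, true) : (x, false)
end
```

A batched variant presumably applies the same update columnwise to a `D × N` matrix of chains, which is what the shape and stationarity tests in this PR exercise.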
Reviewed changes
Copilot reviewed 13 out of 14 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/ParallelMCMC.jl` | Wires in the new interface layer, adjusts imports, and exports the new public API. |
| `src/MALA/MALA.jl` | Adds mass-matrix support, an efficient full-step API, and batched/GPU-oriented MALA kernels. |
| `src/interface.jl` | Implements AbstractMCMC model/sampler/state/transition types and sampling/bundling behavior. |
| `src/DEER/DEER.jl` | Adds configurable AD backend plumbing and a damping-update simplification. |
| `ext/LogDensityProblemsExt.jl` | Adds a `DensityModel(ld)` constructor via Julia's extension mechanism. |
| `Project.toml` | Updates dependencies/weakdeps/extensions and compat entries to support the new features. |
| `test/Project.toml` | Adds CUDA/MCMCChains/Statistics to the test environment. |
| `test/LocalPreferences.toml` | Forces a local CUDA runtime preference for the test environment. |
| `test/test-GPU-MALA.jl` | Adds conditional CUDA tests validating device residency and correctness properties. |
| `test/test-Batched-MALA.jl` | Adds CPU tests for batched MALA shapes, rejection behavior, and stationarity. |
| `test/test-Adaptive-MALA.jl` | Adds tests for dual-averaging adaptation and Chains output internals. |
| `test/test-AbstractMCMC-Interface.jl` | Adds end-to-end interface tests for step/sample/Chains bundling. |
| `test/test-DEER-Interface.jl` | Adds end-to-end interface tests for DEER construction and sampling. |
| `test/test-Deer-vs-MALA.jl` | Minor change within existing test logic. |
Comments suppressed due to low confidence (1)
src/interface.jl:609
`DensityModel` is already a concrete type defined earlier in this module. Defining `function DensityModel end` attempts to rebind the same name to a generic function, which will error at load time ("invalid redefinition of constant DensityModel"). Remove this line; the extension can add `ParallelMCMC.DensityModel(::Any)` methods without a placeholder.
```julia
This method is defined in the `LogDensityProblemsExt` extension and is only
available when `LogDensityProblems` has been loaded.
"""
function DensityModel end # extended by LogDensityProblemsExt
```
```diff
  μ = x .+ ϵ .* _apply_M(gradlogp_x, cholM)
  d = length(x)
  r = y .- μ
- return -0.5 * dot(r, r) / (2ϵ) - (d / 2) * log(4π * ϵ)
+ return -0.5 * _quad_Minv(r, cholM) / (2ϵ) - (d / 2) * log(4π * ϵ) - 0.5 * _logdet_M(cholM)
  end
```
logq_mala's constant terms (`0.0`, `-0.5`, `4π`, etc.) force Float64 arithmetic even when `x`/`ϵ` are Float32 (notably on GPU), which can significantly slow kernels and undermine the goal of staying in Float32. Consider making the constants type-stable with respect to `ϵ`/`eltype(x)` (e.g., set `T = typeof(ϵ)` and use `T(4)*T(π)`, `zero(T)`, and `T(0.5)`) so the result stays in Float32 when appropriate.
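A minimal sketch of the typed-constant approach, with identity-mass-matrix stubs standing in for the `_apply_M`/`_quad_Minv`/`_logdet_M` helpers referenced in the diff (the real package versions also handle a Cholesky factor):

```julia
using LinearAlgebra

# Identity-mass-matrix stubs; illustrative only, not the package's helpers.
_apply_M(g, ::Nothing) = g
_quad_Minv(r, ::Nothing) = dot(r, r)
_logdet_M(::Nothing) = 0.0

function logq_mala_typed(y, x, gradlogp_x, ϵ, cholM)
    T = typeof(ϵ)                        # keep every constant in ϵ's type
    μ = x .+ ϵ .* _apply_M(gradlogp_x, cholM)
    r = y .- μ
    d = length(x)
    half, two, fourπ = T(0.5), T(2), T(4) * T(π)
    return -half * _quad_Minv(r, cholM) / (two * ϵ) -
           (T(d) / two) * log(fourπ * ϵ) - half * T(_logdet_M(cholM))
end
```

Calling this with `ϵ::Float32` and `Float32` arrays yields a `Float32` result, which is the property the comment asks for.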
src/MALA/MALA.jl (Outdated)
```julia
μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM)
R = Y .- μ
q = _quad_Minv_batched(R, cholM)
ldet = _logdet_M(cholM)
return @. -0.5 * q / (2ε) - (D / 2) * log(4π * ε) - 0.5 * ldet
```
logq_mala_batched will promote to Float64 for Float32 inputs because of Float64 literals (-0.5, 4π, and _logdet_M(::Nothing)=0.0). On GPU this can be a large performance hit. Consider rewriting the expression to keep computations in eltype(X)/typeof(ε) (typed constants and zero(T)), so batched GPU paths remain Float32 throughout.
Suggested change:

```diff
-μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM)
-R = Y .- μ
-q = _quad_Minv_batched(R, cholM)
-ldet = _logdet_M(cholM)
-return @. -0.5 * q / (2ε) - (D / 2) * log(4π * ε) - 0.5 * ldet
+T = typeof(ε)
+μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM)
+R = Y .- μ
+q = _quad_Minv_batched(R, cholM)
+ldet = _logdet_M(cholM)
+half = T(-0.5)
+two = T(2)
+fourπ = T(4π)
+D_over_2 = T(D) / two
+qT = T.(q)
+ldetT = T(ldet)
+return @. half * qT / (two * ε) - D_over_2 * log(fourπ * ε) + half * ldetT
```
```julia
function logp_scaled(X)
    D = size(X, 1)
    scales = CUDA.CuArray(Float32.(1:D))  # D-vector on GPU
    return vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1))
end
function gradlogp_scaled(X)
    D = size(X, 1)
    scales = CUDA.CuArray(Float32.(1:D))
```
logp_scaled/gradlogp_scaled allocate a new scales CuArray on every call. In the stationary-distribution loop this happens thousands of times and can dominate runtime. Consider precomputing scales once (outside the functions, or by closing over it) and reusing it.
Suggested change:

```diff
-function logp_scaled(X)
-    D = size(X, 1)
-    scales = CUDA.CuArray(Float32.(1:D)) # D-vector on GPU
-    return vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1))
-end
-function gradlogp_scaled(X)
-    D = size(X, 1)
-    scales = CUDA.CuArray(Float32.(1:D))
+const _scaled_normal_scales_cache = Dict{Int, CUDA.CuArray{Float32,1}}()
+# Retrieve or create the GPU scales vector for a given dimension D.
+# This avoids repeated CuArray allocations inside tight loops.
+function _get_scaled_normal_scales(D::Int)
+    scales = get(_scaled_normal_scales_cache, D, nothing)
+    if scales === nothing || length(scales) != D
+        scales = CUDA.CuArray(Float32.(1:D))  # D-vector on GPU
+        _scaled_normal_scales_cache[D] = scales
+    end
+    return scales
+end
+function logp_scaled(X)
+    D = size(X, 1)
+    scales = _get_scaled_normal_scales(D)
+    return vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1))
+end
+function gradlogp_scaled(X)
+    D = size(X, 1)
+    scales = _get_scaled_normal_scales(D)
```
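The closure alternative mentioned in the comment ("or by closing over it") avoids the global cache entirely. A sketch, with plain `Float32` `Array`s standing in for the `CuArray`s so it runs without a GPU (`make_scaled_target` is an illustrative name):

```julia
# Build the scales vector once and let both test functions capture it,
# so no per-call allocation happens inside the sampling loop.
function make_scaled_target(D::Int)
    scales = Float32.(1:D)                                # allocated once
    logp(X) = vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1))
    gradlogp(X) = -X ./ scales                            # ∂/∂xᵢ of -0.5·Σ xᵢ²/sᵢ
    return logp, gradlogp
end
```

On GPU the only change would be `scales = CUDA.CuArray(Float32.(1:D))`; the closure still allocates it once per target rather than once per call.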
Adds proper GPU support