@penelopeysm penelopeysm commented Nov 13, 2025

You need to load several development versions, specifically

- TuringLang/DynamicPPL.jl#1129 (which gives the recently discussed speedups)
- TuringLang/Turing.jl#2713 (which adds compatibility for the DynamicPPL changes)
- plus this PR of course!

(A fix for TuringLang/AbstractMCMC.jl#185 would likely make things faster by a constant amount too, but it's not needed for this)

Once you do, you get the following:

using Turing
import NUTS as NNUTS
using LinearAlgebra: I, UniformScaling  # the model below uses `I`

NNUTS.square(M::UniformScaling) = M * M

y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(sigma)), tau^2 * I)
    for i in eachindex(sigma)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end

model = eight_schools(y, sigma)
adtype = AutoForwardDiff()
N = 1_000
burnin = 10_000
# Annoyingly Turing.Inference.NUTS uses `nadapts` for burn-in, whereas most other samplers
# would use `num_warmup`.
kwargs = (chain_type=Any, thinning=10, progress=false, verbose=false)

spl = Turing.Inference.NUTS(; adtype=adtype)
sample(model, spl, N; nadapts=burnin, kwargs...);
@time sample(model, spl, N; nadapts=burnin, kwargs...);
# 1.777308 seconds (39.71 M allocations: 4.818 GiB, 13.76% gc time)

spl = NNUTS.FastNUTS(0.05, adtype)
sample(model, spl, N; num_warmup=burnin, kwargs...);
@time sample(model, spl, N; num_warmup=burnin, kwargs...);
# 7.014555 seconds (78.27 M allocations: 17.637 GiB, 9.12% gc time)

spl = NNUTS.FastNUTS(1.0, adtype)
sample(model, spl, N; num_warmup=burnin, kwargs...);
@time sample(model, spl, N; num_warmup=burnin, kwargs...);
# 0.259686 seconds (6.16 M allocations: 612.315 MiB, 14.63% gc time)

(CmdStan via Python takes around 0.5 seconds.)

@penelopeysm (Author)

Oh, I realised the issue: with stepsize=1.0, you get rubbish results. Turing's NUTS implementation automatically determines stepsize=0.05, which gives much better results but is quite a bit slower.
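(As a standalone aside, not using any of the packages in this thread: the reason an overly large step size gives rubbish results is that the leapfrog integrator's energy error explodes past the integrator's stability limit, which collapses the acceptance probability. A minimal Python sketch on a standard-normal target, where that limit sits at a step size of 2:)

```python
import numpy as np

def leapfrog_energy_error(eps, n_steps=50, q0=1.0, p0=1.0):
    """Integrate the dynamics of H = (q^2 + p^2)/2 (a standard-normal
    target) with leapfrog and return the absolute energy error, which
    is what drives HMC/NUTS rejections."""
    q, p = q0, p0
    h0 = 0.5 * (q**2 + p**2)
    for _ in range(n_steps):
        p -= 0.5 * eps * q  # half momentum step (the potential gradient is q)
        q += eps * p        # full position step
        p -= 0.5 * eps * q  # half momentum step
    return abs(0.5 * (q**2 + p**2) - h0)

print(leapfrog_energy_error(0.05))  # tiny: nearly perfect acceptance
print(leapfrog_energy_error(2.1))   # past the stability limit: error explodes
```

For a funnel-shaped posterior like eight schools, the locally appropriate step size varies across the posterior, so a globally fixed stepsize=1.0 can easily sit on the wrong side of this limit in the funnel's neck.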


nsiccha commented Nov 13, 2025

Oh, cool! Let me check it out - there are a few things that you have to be careful about.


nsiccha commented Nov 13, 2025

Do you mind also sharing the cmdstanpy code?

@penelopeysm (Author)

Of course:

data {
  int<lower=0> J; // number of schools
  array[J] real y; // estimated treatment
  array[J] real<lower=0> sigma; // std of estimated effect
}
parameters {
  array[J] real theta; // treatment effect in school j
  real mu; // hyper-parameter of mean
  real<lower=0> tau; // hyper-parameter of the standard deviation
}
model {
  tau ~ cauchy(0, 5); // a non-informative prior
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
  mu ~ normal(0, 5);
}

and then:

from cmdstanpy import CmdStanModel, install_cmdstan
from pathlib import Path
import time

install_cmdstan()

DATA = {
    "y": [28, 8, -3, 7, -1, 1, 18, 12],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
    "J": 8,
}

def main():
    stan_file = Path(__file__).parent / "eight_schools_centered.stan"
    model = CmdStanModel(stan_file=stan_file)
    x = time.time()
    fit = model.sample(data=DATA, chains=1,
                       iter_warmup=10000, save_warmup=False,
                       iter_sampling=10000, thin=10)
    y = time.time()
    print(fit.summary())
    print(f"Time taken: {y - x} seconds")

if __name__ == "__main__":
    main()


nsiccha commented Nov 13, 2025

Great, thanks! I'll have a closer look tomorrow!


nsiccha commented Nov 14, 2025

Hey @penelopeysm, first of all, thanks for taking a stab at this, seeing the code you wrote is very helpful!

However, I fear there has been a misunderstanding about what NUTS.jl is supposed to provide - I blame whoever is responsible for Turing's sampler interface being Turing.Inference.NUTS(; adtype=adtype)!

Do note by the way that I haven't registered NUTS.jl and am also not using its code in my WarmupHMC.jl package, because a) there's a small but non-negligible risk that there's something wrong in the implementation of NUTS, and b) any significant speed benefits would only materialize for very small/cheap models, which is not the type of posterior I'm targeting with WarmupHMC.jl anyway.


In any case, NUTS.nuts!! is not supposed to compete with Turing.Inference.NUTS, but with AdvancedHMC.jl's HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn())). Meaning that NUTS.nuts!! is not supposed to do any of the kernel hyperparameter tuning that's necessary to actually get draws from a posterior efficiently - that is WarmupHMC.jl's job. Instead, NUTS.nuts!! would rely on another package (e.g. AdvancedHMC.jl or WarmupHMC.jl) to tune these hyperparameters.

The "proper" comparison is thus not with Turing.jl's sample method (which also does the hyperparameter tuning), but with AdvancedHMC.jl's transition, or alternatively with DynamicHMC.jl's sample_tree, or with Stan.jl's stan_sample for a fixed step size and mass matrix.

The code below sets up a benchmark similar to yours, but for this "proper" comparison. Performance has to be measured as "how many leapfrog steps can be performed per unit of time", as the sampling algorithms should be identical. Differences in performance in this benchmark are thus (almost) fully due to different overheads in the method implementations, neglecting any optimizations that the C++ compiler can perform for CmdStan because there everything is C++ and transparent to it. This benchmark now uses the same underlying Stan model for (almost) all methods.

ENV["CMDSTAN"] = "/home/niko/.cmdstan/cmdstan-2.36.0"
using NUTS, LinearAlgebra, AdvancedHMC, BridgeStan, StanLogDensityProblems, Stan, StanBlocks, JSON, Random, DynamicHMC
# Stan.set_cmdstan_home!(ENV["CMDSTAN"])
NUTS.square(M::UniformScaling) = M * M
nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position)
    n_leapfrog = 0
    for i in 1:n_samples
        state = nuts!!(state)
        samples[:, i] .= state.position
        n_leapfrog += state.n_leapfrog
    end
    n_leapfrog
end
faulty_nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position)
    for i in 1:n_samples
        # After first use of nuts!!, optional properties, including preallocated working memory,
        # will be set in the returned `new_state` and reused in subsequent calls. 
        (;rng, posterior, stepsize, position) = nuts!!(state)
        state = (;rng, posterior, stepsize, position)
        samples[:, i] .= state.position
    end
    state
end
dynamichmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    algorithm = DynamicHMC.NUTS()
    m = Diagonal(ones(size(samples, 1)))
    H = DynamicHMC.Hamiltonian(DynamicHMC.GaussianKineticEnergy(m, inv(m)), posterior)
    Q = DynamicHMC.evaluate_ℓ(posterior, position; strict=true)
    n_leapfrog = 0
    for i in 1:n_samples
        Q, stats = DynamicHMC.sample_tree(rng, algorithm, H, Q, stepsize)
        samples[:, i] .= Q.q
        n_leapfrog += stats.steps
    end
    n_leapfrog
end
advancedhmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    d = size(samples, 1)
    h = AdvancedHMC.Hamiltonian(UnitEuclideanMetric(d), posterior)
    kernel = HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn()))
    z = AdvancedHMC.phasepoint(rng, position, h) 
    n_leapfrog = 0
    for i in 1:n_samples
        (;stat, z) = AdvancedHMC.transition(rng, h, kernel, z)
        samples[:, i] .= z.θ
        n_leapfrog += stat.n_steps
    end
    n_leapfrog
end
stan_sample!(samples, rng, sm, data; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = Stan.stan_sample(
    sm; data, num_chains=1, num_samples=n_samples, num_warmups=0, engaged=false, stepsize, init=Dict("mu"=>position[1], "tau"=>exp(position[2]), "theta_xi"=>position[3:end]), seed=1
)
function NUTS.log_density_gradient!(
    prob::StanProblem{M,nan_on_error}, x, g
) where {M,nan_on_error}
    m = prob.model
    z = convert(Vector{Float64}, x)
    try
        return BridgeStan.log_density_gradient!(m, z, g)[1]
    catch
        nan_on_error || rethrow()
        g .= NaN
        return NaN
    end
end


begin
    myprint = Base.Fix1(println, "Total number of leapfrog steps: ")
    nc_eight_schools = @slic begin 
        mu ~ normal(0, 5)
        tau ~ cauchy(0, 5;lower=0)
        theta_xi ~ std_normal(;n=J)
        y ~ normal(mu + theta_xi * tau, sigma)
    end
    y = [28, 8, -3, 7, -1, 1, 18, 12] .|> Float64
    sigma = [15, 10, 16, 11, 9, 11, 10, 18] .|> Float64
    sm = (@isdefined sm) ? sm : Stan.SampleModel("nc_eight_schools", stan_code(nc_eight_schools(;y, sigma, J=length(y))))
    tmp = stan_code(nc_eight_schools(;y, sigma, J=length(y)))
    sm_no_gq = (@isdefined sm_no_gq) ? sm_no_gq : Stan.SampleModel("nc_eight_schools_no_gq", tmp[1:findfirst("generated", tmp)[1]-1])
    n_samples = 1000
    samples = zeros((10, n_samples))
    data = Dict("y"=>y, "sigma"=>sigma, "y_n"=>length(y), "sigma_n"=>length(sigma), "J"=>length(y))
    nc_eight_schools_stan = stan_instantiate(nc_eight_schools(;y, sigma, J=length(y)))
    stepsize = .05
    @info "NUTS.jl"
    nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "Faulty NUTS.jl (doesn't actually seem to make a big difference)"
    faulty_nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    myprint("Same as before")
    @time faulty_nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "DynamicHMC.jl"
    dynamichmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time dynamichmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "AdvancedHMC.jl"
    advancedhmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time advancedhmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "CmdStan via Stan.jl (including generated quantities)"
    stan_sample!(samples, Xoshiro(1), sm, data; stepsize)
    Stan.read_samples(sm, :dataframe; include_internals=true).n_leapfrog__ |> sum |> myprint
    @time stan_sample!(samples, Xoshiro(1), sm, data; stepsize)
    @info "CmdStan via Stan.jl (without generated quantities)"
    stan_sample!(samples, Xoshiro(1), sm_no_gq, data; stepsize)
    Stan.read_samples(sm_no_gq, :dataframe; include_internals=true).n_leapfrog__ |> sum |> myprint
    @time stan_sample!(samples, Xoshiro(1), sm_no_gq, data; stepsize)
    nothing
end

The output that I get is the following:

[ Info: NUTS.jl
Total number of leapfrog steps: 82760
  0.113399 seconds (83.87 k allocations: 1.492 MiB)
[ Info: Faulty NUTS.jl (doesn't actually seem to make a big difference)
Total number of leapfrog steps: Same as before
  0.127954 seconds (183.77 k allocations: 15.866 MiB)
[ Info: DynamicHMC.jl
Total number of leapfrog steps: 83096
  0.242021 seconds (1.56 M allocations: 193.785 MiB)
[ Info: AdvancedHMC.jl
Total number of leapfrog steps: 82328
  0.303216 seconds (1.65 M allocations: 216.637 MiB, 14.31% gc time)
[ Info: CmdStan via Stan.jl (including generated quantities)
Total number of leapfrog steps: 82104
  0.119980 seconds (651 allocations: 48.992 KiB)
[ Info: CmdStan via Stan.jl (without generated quantities)
Total number of leapfrog steps: 82104
  0.122649 seconds (651 allocations: 48.992 KiB)

I.e. for this posterior, NUTS.jl is as fast as CmdStan via Stan.jl, DynamicHMC.jl is roughly twice as slow, and AdvancedHMC.jl is roughly three times as slow.
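For a rough sense of scale, the timings and step counts reported above work out to roughly 1.4 microseconds per leapfrog step for NUTS.jl and CmdStan, versus roughly 2.9 and 3.7 for DynamicHMC.jl and AdvancedHMC.jl (back-of-the-envelope arithmetic over the printed output, sketched in Python):

```python
# (seconds, leapfrog steps), copied from the benchmark output above
results = {
    "NUTS.jl": (0.113399, 82760),
    "DynamicHMC.jl": (0.242021, 83096),
    "AdvancedHMC.jl": (0.303216, 82328),
    "CmdStan (no GQ)": (0.122649, 82104),
}
for name, (seconds, steps) in results.items():
    print(f"{name}: {1e6 * seconds / steps:.2f} microseconds per leapfrog step")
```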

Output from ] status (I've removed Turing/DynamicPPL because it didn't play nice with Pathfinder, which I needed for the second part below):

Status `~/github/penelopeysm/NUTS.jl/bench/Project.toml`
  [0bf59076] AdvancedHMC v0.8.3
  [c88b6f0a] BridgeStan v2.7.0
  [a93c6f00] DataFrames v1.8.1
  [bbc10e6e] DynamicHMC v3.5.1
⌅ [682c06a0] JSON v0.21.4
  [6fdf6af0] LogDensityProblems v2.2.0
  [be115224] MCMCDiagnosticTools v0.3.15
  [315f5978] NUTS v0.1.0 `..`
  [b1d3bc72] Pathfinder v0.9.27 ⚲
  [682df890] Stan v10.8.1
  [2e771a56] StanBlocks v0.1.4 `https://github.com/nsiccha/StanBlocks.jl#main`
  [a545de4d] StanLogDensityProblems v0.1.10
  [22787eb5] Term v2.0.7
  [60658175] WarmupHMC v0.2.0 `https://github.com/nsiccha/WarmupHMC.jl#main` ⚲
  [9a3f8284] Random


nsiccha commented Nov 14, 2025

And finally, if you want to compare WarmupHMC.jl against Turing.jl's "full" sampling functionality via sample, including the way it tunes the kernel hyperparameters, the code below does that (via AdvancedHMC.jl, because I can't have the proper versions of Turing.jl and WarmupHMC.jl in the same project).

Of course, because now we do not simply care about wall time, or the number of leapfrog steps that can be performed per time, we have to be a bit more careful about the benchmarking/comparison.

For example, you can't really do it properly for the centered version of the eight schools model, because all of the options would give you the wrong answer, and it doesn't really matter how efficiently/quickly they can provide you with a wrong answer - they aren't LLMs after all.

The evaluation metric now could reasonably be the effective sample size per amount of work, where amount of work can of course be measured as wall time (relevant to the user, but depends on the efficiency of the implementation, especially for small models) or as number of leapfrog steps required (relevant to the user only if the cost of evaluating the log density gradient can be expected to dominate the wall time for the posteriors the user cares about).
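As a toy illustration of that metric (a standalone sketch with a deliberately crude autocorrelation-based ESS estimator, not the MCMCDiagnosticTools one): efficiency is the minimum ESS across parameters per thousand leapfrog steps, so an uncorrelated chain scores far higher than a sticky one at the same step budget.

```python
import numpy as np

def crude_ess(x):
    """Very rough ESS estimate via the initial positive sequence of
    autocorrelations -- for illustration only; use a proper library
    (e.g. MCMCDiagnosticTools or ArviZ) in practice."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    x = x - x.mean()
    acov = np.correlate(x, x, mode="full")[n - 1:] / n
    rho = acov / acov[0]
    tau = 1.0
    for k in range(1, n):
        if rho[k] <= 0:
            break
        tau += 2.0 * rho[k]
    return n / tau

def efficiency(samples, n_leapfrog):
    """Minimum ESS across parameters per 1000 leapfrog steps."""
    ess_min = min(crude_ess(col) for col in np.asarray(samples).T)
    return 1e3 * ess_min / n_leapfrog

rng = np.random.default_rng(1)
iid = rng.normal(size=(1000, 2))  # an "ideal" sampler: ESS close to n
ar = np.zeros((1000, 1))          # a sticky chain: AR(1) with coefficient 0.9
for t in range(1, 1000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()
print(efficiency(iid, n_leapfrog=20_000))  # high
print(efficiency(ar, n_leapfrog=20_000))   # much lower
```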

As I'm not targeting small/cheap models, I personally don't really care about the wall time for small models such as the eight schools model - though I am a bit surprised by the amount of overhead WarmupHMC.jl introduces 😅

    using WarmupHMC, Term, DataFrames, MCMCDiagnosticTools, Statistics
    rows = map(1:10) do seed
        stan_time = @elapsed Stan.stan_sample(sm_no_gq; data, num_chains=1, num_samples=n_samples, num_warmups=n_samples, save_warmup=true, engaged=true, seed)
        df = Stan.read_samples(sm_no_gq, :dataframe; include_internals=true)
        stan_steps = sum(df.n_leapfrog__)
        warmuphmc_time = @elapsed warmuphmc_rv = WarmupHMC.adaptive_warmup_mcmc(Xoshiro(seed), nc_eight_schools_stan)
        warmuphmc_steps = warmuphmc_rv.total_evaluation_counter
        warmuphmc_ess = minimum(ess(reshape(warmuphmc_rv.posterior_position', (n_samples, 1, :))))
        warmuphmc_eff = 1e3 * warmuphmc_ess / warmuphmc_steps
        advancedhmc_time = @elapsed advancedhmc_rv = AdvancedHMC.sample(Xoshiro(seed), nc_eight_schools_stan, AdvancedHMC.NUTS(.8), 2*n_samples, n_adapts=n_samples, progress=false)
        advancedhmc_steps = sum(x->x.stat.n_steps, advancedhmc_rv)
        advancedhmc_ess = minimum(ess(reshape(mapreduce(x->x.z.θ, hcat, advancedhmc_rv[n_samples+1:end])', (n_samples, 1, :))))
        advancedhmc_eff = 1e3 * advancedhmc_ess / advancedhmc_steps
        (;
            stan_time, 
            stan_steps, 
            warmuphmc_time,
            warmuphmc_steps,
            warmuphmc_ess,
            warmuphmc_eff,
            advancedhmc_time,
            advancedhmc_steps,
            advancedhmc_ess,
            advancedhmc_eff,
        )
    end
    df = DataFrame(rows)
    mapcols(median, df)

This outputs for me:

 Row │ stan_time  stan_steps  warmuphmc_time  warmuphmc_steps  warmuphmc_ess  warmuphmc_eff  advancedhmc_time  advancedhmc_steps  advancedhmc_ess  advancedhmc_eff 
     │ Float64    Float64     Float64         Float64          Float64        Float64        Float64           Float64            Float64          Float64         
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ 0.0455723     18037.5        0.112125          12060.0        552.504        25.1158          0.057312            19406.5          578.374          28.8559

I.e. for this posterior, Stan's implementation is the most efficient in terms of wall time, followed closely by AdvancedHMC.jl, with WarmupHMC.jl being almost twice as slow.
In terms of the number of leapfrog steps that had to be performed, Stan and AdvancedHMC.jl should actually be identical (as AFAIK their adaptation strategies are identical), but Stan appears to be slightly better - maybe this is only due to inherent statistical noise and the small sample size. WarmupHMC.jl does take significantly fewer steps than Stan or AdvancedHMC.jl, but this doesn't seem to result in a higher sampling efficiency, probably because the ESS estimates for this posterior are lower for WarmupHMC.jl's method.

WarmupHMC.jl performing slightly worse for this posterior is okay to me. It may in part be because the hyperparameters passed to WarmupHMC.jl aren't great for this posterior, but that's alright - they can't be perfect for all posteriors after all!


nsiccha commented Nov 14, 2025

If you want to have an example where WarmupHMC.jl outperforms Stan or AdvancedHMC.jl due to algorithmic improvements, you can add PosteriorDB.jl to the project and run the following code:

    using PosteriorDB
    if !@isdefined pdb 
        const pdb = PosteriorDB.database()
    end
    stan_problem(path, data) = StanProblem(
        path, data;
        nan_on_error=true,
        make_args=["STAN_THREADS=TRUE"],
        warn=false
    )
    stan_problem(posterior_name::AbstractString) = stan_problem(
        PosteriorDB.path(PosteriorDB.implementation(PosteriorDB.model(PosteriorDB.posterior(pdb, (posterior_name))), "stan")), 
        PosteriorDB.load(PosteriorDB.dataset(PosteriorDB.posterior(pdb, (posterior_name))), String)
    )
    posterior_name = "diamonds-diamonds"
    @info posterior_name
    posterior = stan_problem(posterior_name)
    # A single run is sufficient because AdvancedHMC.jl (always) does so badly and I can't be bothered to wait for several minutes.
    rows = map(1:1) do seed
        warmuphmc_time = @elapsed warmuphmc_rv = WarmupHMC.adaptive_warmup_mcmc(Xoshiro(seed), posterior)
        warmuphmc_steps = warmuphmc_rv.total_evaluation_counter
        warmuphmc_ess = minimum(ess(reshape(warmuphmc_rv.posterior_position', (n_samples, 1, :))))
        warmuphmc_eff = 1e3 * warmuphmc_ess / warmuphmc_steps
        advancedhmc_time = @elapsed advancedhmc_rv = AdvancedHMC.sample(Xoshiro(seed), posterior, AdvancedHMC.NUTS(.8), 2*n_samples, n_adapts=n_samples, progress=false)
        advancedhmc_steps = sum(x->x.stat.n_steps, advancedhmc_rv)
        advancedhmc_ess = minimum(ess(reshape(mapreduce(x->x.z.θ, hcat, advancedhmc_rv[n_samples+1:end])', (n_samples, 1, :))))
        advancedhmc_eff = 1e3 * advancedhmc_ess / advancedhmc_steps
        (;
            warmuphmc_time,
            warmuphmc_steps,
            warmuphmc_ess,
            warmuphmc_eff,
            advancedhmc_time,
            advancedhmc_steps,
            advancedhmc_ess,
            advancedhmc_eff,
        )
    end
    df = DataFrame(rows)
    mapcols(median, df)

which outputs for me:

[ Info: diamonds-diamonds
 Row │ warmuphmc_time  warmuphmc_steps  warmuphmc_ess  warmuphmc_eff  advancedhmc_time  advancedhmc_steps  advancedhmc_ess  advancedhmc_eff 
     │ Float64         Float64          Float64        Float64        Float64           Float64            Float64          Float64         
─────┼──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │        2.78112          23620.0        663.724        28.1001           104.118          1.64197e6          330.432         0.201242

I.e. for that posterior, WarmupHMC.jl is more than 30 times as fast (in terms of wall time) and more than 140 times as efficient (in terms of ESS per leapfrog step).

Edit: This particular posterior is known to have strong correlations between its parameters, which is why AdvancedHMC.jl's default hyperparameter tuning is so inefficient - I'd think that using a dense mass matrix would help a bit, or even considerably.
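(To make the dense-vs-diagonal point concrete with a standalone sketch: a diagonal mass matrix only rescales axes, so for a strongly correlated, locally Gaussian target the condition number, which limits the stable leapfrog step size, stays large, while a dense mass matrix set to the full covariance whitens the target completely:)

```python
import numpy as np

# A locally Gaussian target whose two parameters have correlation 0.99.
cov = np.array([[1.0, 0.99],
                [0.99, 1.0]])

def cond_after_preconditioning(cov, inv_mass):
    """Condition number of the target after the change of variables implied
    by an inverse mass matrix M^-1 = L L^T; the stable leapfrog step size
    scales with the inverse square root of this number."""
    L = np.linalg.cholesky(inv_mass)
    L_inv = np.linalg.inv(L)
    return np.linalg.cond(L_inv @ cov @ L_inv.T)

diag_inv_mass = np.diag(np.diag(cov))  # diagonal adaptation: marginal variances
dense_inv_mass = cov                   # dense adaptation: the full covariance
print(cond_after_preconditioning(cov, diag_inv_mass))   # ~199: barely helped
print(cond_after_preconditioning(cov, dense_inv_mass))  # ~1: fully whitened
```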
