Turing compatibility #1
Conversation
Oh, I realised the issue: with stepsize=1.0, you get rubbish results. Turing's NUTS implementation automatically determines stepsize=0.05, and that gives much better results, but is quite a bit slower.
---
Oh, cool! Let me check it out - there are a few things that you have to be careful about.
---
Do you mind also sharing the cmdstanpy code?
---
Of course:

```stan
data {
  int<lower=0> J;                // number of schools
  array[J] real y;               // estimated treatment effects
  array[J] real<lower=0> sigma;  // standard errors of the estimated effects
}
parameters {
  array[J] real theta;           // treatment effect in school j
  real mu;                       // hyper-parameter: mean
  real<lower=0> tau;             // hyper-parameter: standard deviation
}
model {
  tau ~ cauchy(0, 5);            // a non-informative prior
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
  mu ~ normal(0, 5);
}
```

and then:

```python
from cmdstanpy import CmdStanModel, install_cmdstan
from pathlib import Path
import time
install_cmdstan()
DATA = {
    "y": [28, 8, -3, 7, -1, 1, 18, 12],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
    "J": 8,
}
def main():
    stan_file = Path(__file__).parent / "eight_schools_centered.stan"
    model = CmdStanModel(stan_file=stan_file)
    x = time.time()
    fit = model.sample(data=DATA, chains=1,
                       iter_warmup=10000, save_warmup=False,
                       iter_sampling=10000, thin=10)
    y = time.time()
    print(fit.summary())
    print(f"Time taken: {y - x} seconds")

if __name__ == "__main__":
    main()
```
---
Great, thanks! I'll have a closer look tomorrow!
---
Hey @penelopeysm, first of all, thanks for taking a stab at this - seeing the code you wrote is very helpful! However, I fear there has been a misunderstanding about what NUTS.jl is supposed to provide - I blame whoever is responsible for the sampler interface in Turing being […].

Do note, by the way, that I haven't registered NUTS.jl and am also not using its code in my WarmupHMC.jl package, because a) there's a small but non-negligible risk that there's something wrong in the implementation of NUTS, and b) any significant speed benefits would only materialize for very small/cheap models, which is not the type of posterior I'm targeting with WarmupHMC.jl anyways.

In any case, the "proper" comparison is thus not with Turing.jl's […]. The code below sets up benchmarking similar to yours, but for this "proper" comparison. Performance has to be measured as "how many leapfrog steps can be performed per unit of time", as the sampling algorithms should be identical. Differences in performance in this benchmark are thus (almost) fully due to different overheads in the method implementation, neglecting any optimizations that the C++ compiler can perform for CmdStan because there everything is C++ and transparent to it. This benchmark now uses the same underlying Stan model for (almost) all methods.

```julia
ENV["CMDSTAN"] = "/home/niko/.cmdstan/cmdstan-2.36.0"
using NUTS, LinearAlgebra, AdvancedHMC, BridgeStan, StanLogDensityProblems, Stan, StanBlocks, JSON, Random, DynamicHMC
# Stan.set_cmdstan_home!(ENV["CMDSTAN"])
NUTS.square(M::UniformScaling) = M * M
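
# Draw n_samples draws with NUTS.jl's nuts!! kernel at a fixed step size,
# storing positions column-wise in `samples`; returns the total leapfrog-step count.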
nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position)
    n_leapfrog = 0
    for i in 1:n_samples
        state = nuts!!(state)
        samples[:, i] .= state.position
        n_leapfrog += state.n_leapfrog
    end
    n_leapfrog
end
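
# "Faulty" variant: rebuilds the state from only the four required properties each
# iteration, so the extra properties returned by nuts!! (including preallocated
# working memory, see the comment inside) are discarded every time.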
faulty_nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position)
    for i in 1:n_samples
        # After the first use of nuts!!, optional properties, including preallocated working memory,
        # will be set in the returned state and reused in subsequent calls.
        (;rng, posterior, stepsize, position) = nuts!!(state)
        state = (;rng, posterior, stepsize, position)
        samples[:, i] .= state.position
    end
    state
end
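
# Same fixed-stepsize loop via DynamicHMC.jl's sample_tree, with an identity mass matrix.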
dynamichmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    algorithm = DynamicHMC.NUTS()
    m = Diagonal(ones(size(samples, 1)))
    H = DynamicHMC.Hamiltonian(DynamicHMC.GaussianKineticEnergy(m, inv(m)), posterior)
    Q = DynamicHMC.evaluate_ℓ(posterior, position; strict=true)
    n_leapfrog = 0
    for i in 1:n_samples
        Q, stats = DynamicHMC.sample_tree(rng, algorithm, H, Q, stepsize)
        samples[:, i] .= Q.q
        n_leapfrog += stats.steps
    end
    n_leapfrog
end
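
# Same loop via AdvancedHMC.jl: unit metric, multinomial NUTS kernel, fixed leapfrog step size.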
advancedhmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    d = size(samples, 1)
    h = AdvancedHMC.Hamiltonian(UnitEuclideanMetric(d), posterior)
    kernel = HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn()))
    z = AdvancedHMC.phasepoint(rng, position, h)
    n_leapfrog = 0
    for i in 1:n_samples
        (;stat, z) = AdvancedHMC.transition(rng, h, kernel, z)
        samples[:, i] .= z.θ
        n_leapfrog += stat.n_steps
    end
    n_leapfrog
end
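
# One CmdStan run with adaptation disabled (engaged=false) and a fixed step size; the init
# dictionary maps the shared unconstrained starting position onto Stan's constrained parameters.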
stan_sample!(samples, rng, sm, data; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = Stan.stan_sample(
    sm; data, num_chains=1, num_samples=n_samples, num_warmups=0, engaged=false, stepsize,
    init=Dict("mu"=>position[1], "tau"=>exp(position[2]), "theta_xi"=>position[3:end]), seed=1
)
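
# Hook NUTS.jl up to the BridgeStan log density: on error, either rethrow or,
# if nan_on_error is set, write NaN into the gradient and return NaN.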
function NUTS.log_density_gradient!(
    prob::StanProblem{M,nan_on_error}, x, g
) where {M,nan_on_error}
    m = prob.model
    z = convert(Vector{Float64}, x)
    try
        return BridgeStan.log_density_gradient!(m, z, g)[1]
    catch
        nan_on_error || rethrow()
        g .= NaN
        return NaN
    end
end
begin
    myprint = Base.Fix1(println, "Total number of leapfrog steps: ")
    nc_eight_schools = @slic begin
        mu ~ normal(0, 5)
        tau ~ cauchy(0, 5; lower=0)
        theta_xi ~ std_normal(; n=J)
        y ~ normal(mu + theta_xi * tau, sigma)
    end
    y = [28, 8, -3, 7, -1, 1, 18, 12] .|> Float64
    sigma = [15, 10, 16, 11, 9, 11, 10, 18] .|> Float64
    sm = (@isdefined sm) ? sm : Stan.SampleModel("nc_eight_schools", stan_code(nc_eight_schools(; y, sigma, J=length(y))))
    tmp = stan_code(nc_eight_schools(; y, sigma, J=length(y)))
    sm_no_gq = (@isdefined sm_no_gq) ? sm_no_gq : Stan.SampleModel("nc_eight_schools_no_gq", tmp[1:findfirst("generated", tmp)[1]-1])
    n_samples = 1000
    samples = zeros((10, n_samples))
    data = Dict("y"=>y, "sigma"=>sigma, "y_n"=>length(y), "sigma_n"=>length(sigma), "J"=>length(y))
    nc_eight_schools_stan = stan_instantiate(nc_eight_schools(; y, sigma, J=length(y)))
    stepsize = 0.05
    @info "NUTS.jl"
    nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "Faulty NUTS.jl (doesn't actually seem to make a big difference)"
    faulty_nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    myprint("Same as before")
    @time faulty_nuts_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "DynamicHMC.jl"
    dynamichmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time dynamichmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "AdvancedHMC.jl"
    advancedhmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize) |> myprint
    @time advancedhmc_sample!(samples, Xoshiro(1), nc_eight_schools_stan; stepsize)
    @info "CmdStan via Stan.jl (including generated quantities)"
    stan_sample!(samples, Xoshiro(1), sm, data; stepsize)
    Stan.read_samples(sm, :dataframe; include_internals=true).n_leapfrog__ |> sum |> myprint
    @time stan_sample!(samples, Xoshiro(1), sm, data; stepsize)
    @info "CmdStan via Stan.jl (without generated quantities)"
    stan_sample!(samples, Xoshiro(1), sm_no_gq, data; stepsize)
    Stan.read_samples(sm_no_gq, :dataframe; include_internals=true).n_leapfrog__ |> sum |> myprint
    @time stan_sample!(samples, Xoshiro(1), sm_no_gq, data; stepsize)
    nothing
end
```

For this posterior, the output I get shows NUTS.jl being as fast as CmdStan via Stan.jl, DynamicHMC.jl roughly twice as slow, and AdvancedHMC.jl roughly three times as slow.
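The benchmark prints the leapfrog-step count and the wall time separately; to collapse them into the single "leapfrog steps per time" figure discussed above, a small helper like the following can be used. This is my own sketch, not part of the benchmark; it assumes a sampler function that, like nuts_sample!, dynamichmc_sample!, and advancedhmc_sample! above, returns the total number of leapfrog steps.

```julia
# Sketch: leapfrog steps per second for one of the fixed-stepsize samplers above.
steps_per_second(sample!, samples, rng, posterior; stepsize) = begin
    elapsed = @elapsed n_leapfrog = sample!(samples, rng, posterior; stepsize)
    n_leapfrog / elapsed
end

# Example usage with the NUTS.jl loop defined above:
steps_per_second(nuts_sample!, samples, Xoshiro(1), nc_eight_schools_stan; stepsize=0.05)
```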
---
And finally, if you want to compare WarmupHMC.jl against Turing.jl's "full" sampling functionality via […]. Of course, because now we do not simply care about wall time, or the number of leapfrog steps that can be performed per unit of time, we have to be a bit more careful about the benchmarking/comparison. For example, you can't really do it properly for the centered version of the eight schools model, because all of the options would give you the wrong answer, and it doesn't really matter how efficiently/quickly they can provide you with a wrong answer - they aren't LLMs after all.

The evaluation metric now could reasonably be the effective sample size per amount of work, where the amount of work can of course be measured as wall time (relevant to the user, but dependent on the efficiency of the implementation, especially for small models) or as the number of leapfrog steps required (relevant to the user only if the cost of evaluating the log density gradient can be expected to dominate the wall time for the posteriors the user cares about). As I'm not targeting small/cheap models, I personally don't really care about the wall time for small models such as the eight schools model - though I am a bit surprised by the amount of overhead WarmupHMC.jl introduces 😅

```julia
using WarmupHMC, Term, DataFrames, MCMCDiagnosticTools, Statistics
rows = map(1:10) do seed
    stan_time = @elapsed Stan.stan_sample(sm_no_gq; data, num_chains=1, num_samples=n_samples, num_warmups=n_samples, save_warmup=true, engaged=true, seed)
    df = Stan.read_samples(sm_no_gq, :dataframe; include_internals=true)
    stan_steps = sum(df.n_leapfrog__)
    warmuphmc_time = @elapsed warmuphmc_rv = WarmupHMC.adaptive_warmup_mcmc(Xoshiro(seed), nc_eight_schools_stan)
    warmuphmc_steps = warmuphmc_rv.total_evaluation_counter
    warmuphmc_ess = minimum(ess(reshape(warmuphmc_rv.posterior_position', (n_samples, 1, :))))
    warmuphmc_eff = 1e3 * warmuphmc_ess / warmuphmc_steps
    advancedhmc_time = @elapsed advancedhmc_rv = AdvancedHMC.sample(Xoshiro(seed), nc_eight_schools_stan, AdvancedHMC.NUTS(.8), 2*n_samples, n_adapts=n_samples, progress=false)
    advancedhmc_steps = sum(x->x.stat.n_steps, advancedhmc_rv)
    advancedhmc_ess = minimum(ess(reshape(mapreduce(x->x.z.θ, hcat, advancedhmc_rv[n_samples+1:end])', (n_samples, 1, :))))
    advancedhmc_eff = 1e3 * advancedhmc_ess / advancedhmc_steps
    (;
        stan_time,
        stan_steps,
        warmuphmc_time,
        warmuphmc_steps,
        warmuphmc_ess,
        warmuphmc_eff,
        advancedhmc_time,
        advancedhmc_steps,
        advancedhmc_ess,
        advancedhmc_eff,
    )
end
df = DataFrame(rows)
mapcols(median, df)
```

For this posterior, the output shows Stan's implementation being the most efficient in terms of wall time, followed closely by AdvancedHMC.jl, with WarmupHMC.jl being almost twice as slow. WarmupHMC.jl performing slightly worse for this posterior is okay with me. It may be partly because the hyperparameters passed to WarmupHMC.jl aren't great for this posterior, but that's alright - they can't be perfect for all posteriors, after all!
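Since the table reports ESS, wall time, and leapfrog steps separately, the wall-time analogue of the efficiency columns is easy to derive. A sketch (my own addition, reusing the DataFrame columns computed above):

```julia
# ESS per second of wall time - the wall-time counterpart of the
# ESS-per-1000-leapfrog-steps "eff" columns above.
df.warmuphmc_ess_per_sec = df.warmuphmc_ess ./ df.warmuphmc_time
df.advancedhmc_ess_per_sec = df.advancedhmc_ess ./ df.advancedhmc_time
mapcols(median, df)
```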
---
If you want an example where WarmupHMC.jl outperforms Stan or AdvancedHMC.jl due to algorithmic improvements, you can add PosteriorDB.jl to the project and run the following code:

```julia
using PosteriorDB
if !@isdefined pdb
    const pdb = PosteriorDB.database()
end
stan_problem(path, data) = StanProblem(
    path, data;
    nan_on_error=true,
    make_args=["STAN_THREADS=TRUE"],
    warn=false
)
stan_problem(posterior_name::AbstractString) = stan_problem(
    PosteriorDB.path(PosteriorDB.implementation(PosteriorDB.model(PosteriorDB.posterior(pdb, posterior_name)), "stan")),
    PosteriorDB.load(PosteriorDB.dataset(PosteriorDB.posterior(pdb, posterior_name)), String)
)
posterior_name = "diamonds-diamonds"
@info posterior_name
posterior = stan_problem(posterior_name)
# A single run is sufficient because AdvancedHMC.jl (always) does so badly and I can't be bothered to wait for several minutes.
rows = map(1:1) do seed
    warmuphmc_time = @elapsed warmuphmc_rv = WarmupHMC.adaptive_warmup_mcmc(Xoshiro(seed), posterior)
    warmuphmc_steps = warmuphmc_rv.total_evaluation_counter
    warmuphmc_ess = minimum(ess(reshape(warmuphmc_rv.posterior_position', (n_samples, 1, :))))
    warmuphmc_eff = 1e3 * warmuphmc_ess / warmuphmc_steps
    advancedhmc_time = @elapsed advancedhmc_rv = AdvancedHMC.sample(Xoshiro(seed), posterior, AdvancedHMC.NUTS(.8), 2*n_samples, n_adapts=n_samples, progress=false)
    advancedhmc_steps = sum(x->x.stat.n_steps, advancedhmc_rv)
    advancedhmc_ess = minimum(ess(reshape(mapreduce(x->x.z.θ, hcat, advancedhmc_rv[n_samples+1:end])', (n_samples, 1, :))))
    advancedhmc_eff = 1e3 * advancedhmc_ess / advancedhmc_steps
    (;
        warmuphmc_time,
        warmuphmc_steps,
        warmuphmc_ess,
        warmuphmc_eff,
        advancedhmc_time,
        advancedhmc_steps,
        advancedhmc_ess,
        advancedhmc_eff,
    )
end
df = DataFrame(rows)
mapcols(median, df)
```

For that posterior, the output shows WarmupHMC.jl being more than 30 times as fast (in terms of wall time) and more than 140 times as efficient (in terms of ESS per leapfrog step).

Edit: This particular posterior is known to have strong correlations between its parameters, which is why AdvancedHMC.jl's default hyperparameter tuning is so inefficient - I'd think that using a dense mass matrix would help (a bit, or perhaps considerably).
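If you want to try that, AdvancedHMC.jl's lower-level API lets you pick the metric explicitly. The following is a hedged sketch built from AdvancedHMC's documented building blocks (DenseEuclideanMetric, find_good_stepsize, StanHMCAdaptor); I haven't benchmarked it here.

```julia
using AdvancedHMC, LogDensityProblems, Random

# Sketch: NUTS with adaptation of a *dense* mass matrix instead of the default
# diagonal one. Assumes `posterior` supports the LogDensityProblems interface,
# as the StanProblem objects above do.
d = LogDensityProblems.dimension(posterior)
metric = DenseEuclideanMetric(d)
h = AdvancedHMC.Hamiltonian(metric, posterior)
initial_position = randn(Xoshiro(1), d)
integrator = Leapfrog(find_good_stepsize(h, initial_position))
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
samples_dense, stats = AdvancedHMC.sample(
    Xoshiro(1), h, kernel, initial_position, 2 * n_samples, adaptor, n_samples;
    progress=false,
)
```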
---
You need to load several development versions, specifically:
- TuringLang/DynamicPPL.jl#1129 (which gives the recently discussed speedups)
- TuringLang/Turing.jl#2713 (which adds compatibility for the DynamicPPL changes)
- plus this PR, of course!

(A fix for TuringLang/AbstractMCMC.jl#185 would likely make things faster by a constant amount too, but it's not needed for this.)
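For completeness, a sketch of one way to pull these development versions into the active environment; the rev values are placeholders for whatever branches the PRs point at.

```julia
using Pkg

# Placeholders: substitute the actual branch names of the PRs listed above,
# or Pkg.develop local clones instead.
Pkg.add(url="https://github.com/TuringLang/DynamicPPL.jl", rev="<branch-of-PR-1129>")
Pkg.add(url="https://github.com/TuringLang/Turing.jl", rev="<branch-of-PR-2713>")
```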
Once you do, you get the following:
(CmdStan via Python takes around 0.5 seconds.)