-
Notifications
You must be signed in to change notification settings - Fork 48
Open
Description
As described in the title, and inspired by #485; see, for example, the following code:
`AdvancedHMC.jl/src/integrator.jl`, lines 216 to 265 at commit 6bc0c74:
```julia
function step(
    lf::DefaultLeapfrog{FT,T},
    h::Hamiltonian,
    z::P,
    n_steps::Int=1;
    fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
    full_trajectory::Val{FullTraj}=Val(false),
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT},P<:PhasePoint,FullTraj}
    n_steps = abs(n_steps) # to support `n_steps < 0` cases
    ϵ = fwd ? step_size(lf) : -step_size(lf)
    ϵ = ϵ'
    res = FullTraj ? Vector{P}(undef, n_steps) : nothing
    (; θ, r) = z
    (; value, gradient) = z.ℓπ
    for i in 1:n_steps
        # Tempering
        r = temper(lf, r, (i=i, is_half=true), n_steps)
        # Take a half leapfrog step for momentum variable
        r = r - ϵ / 2 .* gradient
        # Take a full leapfrog step for position variable
        ∇r = ∂H∂r(h, r)
        θ = θ + ϵ .* ∇r
        # Take a half leapfrog step for momentum variable
        (; value, gradient) = ∂H∂θ(h, θ)
        r = r - ϵ / 2 .* gradient
        # Tempering
        r = temper(lf, r, (i=i, is_half=false), n_steps)
        # Create a new phase point by caching the logdensity and gradient
        z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient))
        # Update result
        if !isnothing(res)
            res[i] = z
        end
        if !isfinite(z)
            # Remove undef
            if !isnothing(res)
                resize!(res, i)
            end
            break
        end
    end
    return if FullTraj
        res
    else
        z
    end
end
```
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels