Merged
42 changes: 42 additions & 0 deletions .github/workflows/FloatTypes.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Float type promotion

on:
push:
branches:
- main
pull_request:

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
floattypes:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6

- uses: julia-actions/setup-julia@v2
with:
version: "1"

- uses: julia-actions/cache@v2

- name: Run float type tests
working-directory: test/floattypes
run: |
julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()'
julia --project=. --color=yes main.jl setup f64
julia --project=. --color=yes main.jl run f64
julia --project=. --color=yes main.jl setup f32
julia --project=. --color=yes main.jl run f32
julia --project=. --color=yes main.jl setup f16
julia --project=. --color=yes main.jl run f16
julia --project=. --color=yes main.jl setup min
julia --project=. --color=yes main.jl run min
3 changes: 1 addition & 2 deletions .gitignore
@@ -25,7 +25,6 @@ docs/.jekyll-cache
.vscode
.DS_Store
Manifest.toml
/Manifest.toml
/test/Manifest.toml
LocalPreferences.toml

benchmarks/output/
12 changes: 12 additions & 0 deletions HISTORY.md
@@ -1,3 +1,15 @@
# 0.43.3

Unify parameter initialisation for HMC and external samplers.
Both now attempt multiple times to generate valid initial parameters, instead of unconditionally accepting the first set drawn.

Re-export `set_logprob_type!` from DynamicPPL to allow users to control the base log-probability type used when evaluating Turing models.
For example, calling `set_logprob_type!(Float32)` means that Turing will use `Float32` for log-probability calculations, only promoting if something in the model forces a wider type (e.g. a distribution that returns `Float64` log-probabilities).
Note that this is a compile-time preference: for it to take effect you will have to restart your Julia session after calling `set_logprob_type!`.

Furthermore, note that sampler support for non-`Float64` log-probabilities is currently limited.
Although DynamicPPL promises not to promote float types unnecessarily, many samplers, including HMC and NUTS, still use `Float64` internally and thus will cause log-probabilities and parameters to be promoted to `Float64`, even if the model itself uses `Float32`.
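The preference can be set as follows; this is a minimal sketch, and the `demo` model is a hypothetical example rather than anything in the Turing test suite:

```julia
using Turing

# Compile-time preference: takes effect only after restarting Julia.
set_logprob_type!(Float32)

# --- in a fresh session after the restart ---

@model function demo()
    # Float32 literals keep the computation in Float32 unless something
    # in the model forces promotion to a wider type.
    x ~ Normal(0.0f0, 1.0f0)
    return x
end

# Sampling from the prior preserves Float32 log-probabilities; HMC/NUTS
# would still promote to Float64, as noted above.
chn = sample(demo(), Prior(), 10)
```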

# 0.43.2

Throw an `ArgumentError` when a `Gibbs` sampler is missing component samplers for any variable in the model.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.43.2"
version = "0.43.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -61,7 +61,7 @@ DifferentiationInterface = "0.7"
Distributions = "0.25.77"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.40.6"
DynamicPPL = "0.40.15"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.14"
11 changes: 6 additions & 5 deletions docs/src/api.md
@@ -46,6 +46,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
| `setthreadsafe` | [`DynamicPPL.setthreadsafe`](@extref) | Mark a model as requiring threadsafe evaluation |
| `might_produce` | [`Libtask.might_produce`](@extref) | Mark a method signature as potentially calling `Libtask.produce` |
| `@might_produce` | [`Libtask.@might_produce`](@extref) | Mark a function name as potentially calling `Libtask.produce` |
| `set_logprob_type!` | [`DynamicPPL.set_logprob_type!`](@extref) | Set the base log-probability type used during evaluation of Turing models |

### Inference

@@ -81,11 +82,11 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu

### Data structures

| Exported symbol | Documentation | Description |
|:--------------- |:------------------------------------------- |:----------------------------------- |
| `@vnt` | [`DynamicPPL.@vnt`](@extref) | Generate a `VarNameTuple` |
| `VarNamedTuple` | [`DynamicPPL.VarNamedTuple`](@extref) | A mapping from `VarName`s to values |
| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary |
| Exported symbol | Documentation | Description |
|:--------------- |:---------------------------------------------------- |:----------------------------------- |
| `@vnt` | [`DynamicPPL.VarNamedTuples.@vnt`](@extref) | Generate a `VarNameTuple` |
| `VarNamedTuple` | [`DynamicPPL.VarNamedTuples.VarNamedTuple`](@extref) | A mapping from `VarName`s to values |
| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary |

### DynamicPPL utilities

19 changes: 8 additions & 11 deletions ext/TuringDynamicHMCExt.jl
@@ -50,26 +50,23 @@ function AbstractMCMC.step(
initial_params,
kwargs...,
)
# Define log-density function.
# TODO(penelopeysm) We need to check that the initial parameters are valid. Same as how
# we do it for HMC
_, vi = DynamicPPL.init!!(
rng, model, DynamicPPL.VarInfo(), initial_params, DynamicPPL.LinkAll()
)
ℓ = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
# Construct LogDensityFunction
tfm_strategy = DynamicPPL.LinkAll()
ldf = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=spl.adtype
)
x = Turing.Inference.find_initial_params_ldf(rng, ldf, initial_params)

# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
rng, ldf, 0; initialization=(q=x,), reporter=DynamicHMC.NoProgressReport()
)
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Create first sample and state.
sample = DynamicPPL.ParamsWithStats(Q.q, ℓ)
state = DynamicNUTSState(ℓ, Q, steps.H.κ, steps.ϵ)
sample = DynamicPPL.ParamsWithStats(Q.q, ldf)
state = DynamicNUTSState(ldf, Q, steps.H.κ, steps.ϵ)

return sample, state
end
6 changes: 5 additions & 1 deletion src/Turing.jl
@@ -78,7 +78,9 @@ using DynamicPPL:
InitFromParams,
setthreadsafe,
filldist,
arraydist
arraydist,
set_logprob_type!

using StatsBase: predict
using OrderedCollections: OrderedDict
using Libtask: might_produce, @might_produce
@@ -163,6 +165,8 @@ export
fix,
unfix,
OrderedDict, # OrderedCollections
# Log-prob types in accumulators
set_logprob_type!,
# Initialisation strategies for models
InitFromPrior,
InitFromUniform,
38 changes: 38 additions & 0 deletions src/mcmc/abstractmcmc.jl
@@ -31,6 +31,44 @@ function _convert_initial_params(@nospecialize(_::Any))
throw(ArgumentError(errmsg))
end

"""
find_initial_params_ldf(rng, ldf, init_strategy; max_attempts=1000)

Given a `LogDensityFunction` and an initialization strategy, attempt to find valid initial
parameters by sampling from the strategy and checking that the log density (and its
gradient, if an AD type is set) is finite. If no valid parameters are found after
`max_attempts` attempts, throw an error.
"""
function find_initial_params_ldf(
rng::Random.AbstractRNG,
ldf::DynamicPPL.LogDensityFunction,
init_strategy::DynamicPPL.AbstractInitStrategy;
max_attempts::Int=1000,
)
for attempts in 1:max_attempts
# Get new parameters
x = rand(rng, ldf, init_strategy)
is_valid = if ldf.adtype === nothing
logp = LogDensityProblems.logdensity(ldf, x)
isfinite(logp)
else
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
isfinite(logp) && all(isfinite, grad)
end

# If they're OK, return them
is_valid && return x

attempts == 10 &&
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
end

# if we failed to find valid initial parameters, error
return error(
"failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
)
end
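A hedged usage sketch of the function above, assuming `model`, `adtype`, and `rng` are already in scope (the names are illustrative, not part of this PR):

```julia
# Build an LDF in unconstrained space, then ask for initial parameters
# drawn from the prior that yield a finite log density and gradient.
ldf = DynamicPPL.LogDensityFunction(
    model, DynamicPPL.getlogjoint_internal, DynamicPPL.LinkAll(); adtype=adtype
)
x = find_initial_params_ldf(rng, ldf, DynamicPPL.InitFromPrior())
# `x` is a parameter vector in unconstrained (linked) space.
```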

#########################################
# Default definitions for the interface #
#########################################
29 changes: 12 additions & 17 deletions src/mcmc/external_sampler.jl
@@ -149,28 +149,17 @@ function AbstractMCMC.step(
) where {unconstrained}
sampler = sampler_wrapper.sampler

# Initialise varinfo with initial params and link the varinfo if needed.
tfm_strategy = unconstrained ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll()
_, varinfo = DynamicPPL.init!!(rng, model, VarInfo(), initial_params, tfm_strategy)

# We need to extract the vectorised initial_params, because the later call to
# AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
# to be a vector.
initial_params_vector = varinfo[:]

# Construct LogDensityFunction
tfm_strategy = unconstrained ? DynamicPPL.LinkAll() : DynamicPPL.UnlinkAll()
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype
model, DynamicPPL.getlogjoint_internal, tfm_strategy; adtype=sampler_wrapper.adtype
)
x = find_initial_params_ldf(rng, f, initial_params)

# Then just call `AbstractMCMC.step` with the right arguments.
_, state_inner = if initial_state === nothing
AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(f),
sampler;
initial_params=initial_params_vector,
kwargs...,
rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params=x, kwargs...
)

else
@@ -179,7 +168,7 @@
AbstractMCMC.LogDensityModel(f),
sampler,
initial_state;
initial_params=initial_params_vector,
initial_params=x,
kwargs...,
)
end
@@ -191,7 +180,13 @@
new_stats = AbstractMCMC.getstats(state_inner)
DynamicPPL.ParamsWithStats(new_parameters, f, new_stats)
end
return (new_transition, TuringState(state_inner, varinfo, new_parameters, f))

# TODO(penelopeysm): this varinfo is only needed for Gibbs. The external sampler itself
# has no use for it. Get rid of this as soon as possible.
vi = DynamicPPL.link!!(VarInfo(model), model)
vi = DynamicPPL.unflatten!!(vi, x)

return (new_transition, TuringState(state_inner, vi, new_parameters, f))
end

function AbstractMCMC.step(
68 changes: 15 additions & 53 deletions src/mcmc/hmc.jl
@@ -150,41 +150,10 @@ function AbstractMCMC.sample(
end
end

function find_initial_params(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varinfo::DynamicPPL.AbstractVarInfo,
hamiltonian::AHMC.Hamiltonian,
init_strategy::DynamicPPL.AbstractInitStrategy;
max_attempts::Int=1000,
)
varinfo = deepcopy(varinfo) # Don't mutate

for attempts in 1:max_attempts
theta = varinfo[:]
z = AHMC.phasepoint(rng, theta, hamiltonian)
isfinite(z) && return varinfo, z

attempts == 10 &&
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"

# Resample and try again.
_, varinfo = DynamicPPL.init!!(
rng, model, varinfo, init_strategy, DynamicPPL.LinkAll()
)
end

# if we failed to find valid initial parameters, error
return error(
"failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
)
end

function Turing.Inference.initialstep(
function AbstractMCMC.step(
rng::AbstractRNG,
model::DynamicPPL.Model,
spl::Hamiltonian,
vi_original::AbstractVarInfo;
spl::Hamiltonian;
# the initial_params kwarg is always passed on from sample(), cf. DynamicPPL
# src/sampler.jl, so we don't need to provide a default value here
initial_params::DynamicPPL.AbstractInitStrategy,
@@ -193,32 +162,19 @@
verbose::Bool=true,
kwargs...,
)
# Transform the samples to unconstrained space and compute the joint log probability.
vi = DynamicPPL.link(vi_original, model)

# Extract parameters.
theta = vi[:]

# Create a Hamiltonian.
metricT = getmetricT(spl)
metric = metricT(length(theta))
# Create a Hamiltonian
ldf = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
model, DynamicPPL.getlogjoint_internal, DynamicPPL.LinkAll(); adtype=spl.adtype
)
metricT = getmetricT(spl)
metric = metricT(LogDensityProblems.dimension(ldf))
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)

# Note that there is already one round of 'initialisation' before we reach this step,
# inside DynamicPPL's `AbstractMCMC.step` implementation. That leads to a possible issue
# that this `find_initial_params` function might override the parameters set by the
# user.
# Luckily for us, `find_initial_params` always checks if the logp and its gradient are
# finite. If it is already finite with the params inside the current `vi`, it doesn't
# attempt to find new ones. This means that the parameters passed to `sample()` will be
# respected instead of being overridden here.
vi, z = find_initial_params(rng, model, vi, hamiltonian, initial_params)
theta = vi[:]
# Find initial values
theta = find_initial_params_ldf(rng, ldf, initial_params)
z = AHMC.phasepoint(rng, theta, hamiltonian)

# Find good eps if not provided one
if iszero(spl.ϵ)
@@ -236,6 +192,12 @@
else
DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple())
end

# TODO(penelopeysm): this varinfo is only needed for Gibbs. HMC itself has no use for
# it. Get rid of this as soon as possible.
vi = DynamicPPL.link!!(VarInfo(model), model)
vi = DynamicPPL.unflatten!!(vi, theta)

state = HMCState(vi, 0, kernel, hamiltonian, z, adaptor, ldf)

return transition, state