Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# 0.42.9

Improve handling of model evaluator functions with Libtask.

This means that when running SMC or PG on a model with keyword arguments, you no longer need to use `@might_produce` (see the patch notes for v0.42.5 for more details).

It also means that submodels with observations inside will now be reliably handled by the SMC/PG samplers, which was not the case before (the observations were only picked up if the submodel was inlined by the Julia compiler, which could lead to correctness issues).

# 0.42.8

Add support for `TensorBoardLogger.jl` via `AbstractMCMC.mcmc_callback`.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.42.8"
version = "0.42.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -65,7 +65,7 @@ DynamicHMC = "3.4"
DynamicPPL = "0.39.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.5"
Libtask = "0.9.14"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
Expand Down
34 changes: 6 additions & 28 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ function AbstractMCMC.sample(
)
check_model && _check_model(model, sampler)
error_if_threadsafe_eval(model)
check_model_kwargs(model)
# need to add on the `nparticles` keyword argument for `initialstep` to make use of
return AbstractMCMC.mcmcsample(
rng,
Expand All @@ -138,28 +137,6 @@ function AbstractMCMC.sample(
)
end

function check_model_kwargs(model::DynamicPPL.Model)
if !isempty(model.defaults)
# If there are keyword arguments, we need to check that the user has
# accounted for this by overloading `might_produce`.
might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f)))
if !might_produce
io = IOBuffer()
ctx = IOContext(io, :color => true)
print(
ctx,
"Models with keyword arguments need special treatment to be used" *
" with particle methods. Please run:\n\n",
)
printstyled(
ctx, " Turing.@might_produce($(model.f))"; bold=true, color=:blue
)
print(ctx, "\n\nbefore sampling from this model with particle methods.\n")
error(String(take!(io)))
end
end
end

function Turing.Inference.initialstep(
rng::AbstractRNG,
model::DynamicPPL.Model,
Expand All @@ -169,7 +146,6 @@ function Turing.Inference.initialstep(
discard_sample=false,
kwargs...,
)
check_model_kwargs(model)
# Reset the VarInfo.
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
vi = DynamicPPL.empty!!(vi)
Expand Down Expand Up @@ -292,7 +268,6 @@ function Turing.Inference.initialstep(
kwargs...,
)
error_if_threadsafe_eval(model)
check_model_kwargs(model)
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())

# Create a new set of particles
Expand Down Expand Up @@ -534,6 +509,9 @@ Libtask.@might_produce(DynamicPPL.tilde_observe!!)
# Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext?
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
Libtask.@might_produce(DynamicPPL.evaluate!!)
Libtask.@might_produce(DynamicPPL.init!!)
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

# This handles all models and submodel evaluator functions (including those with keyword
# arguments). The key to this is realising that all model evaluator functions take
# DynamicPPL.Model as an argument, so we can just check for that. See
# https://github.com/TuringLang/Libtask.jl/issues/217.
Libtask.might_produce_if_sig_contains(::Type{<:DynamicPPL.Model}) = true
24 changes: 21 additions & 3 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
module AquaTests

using Aqua: Aqua
using Libtask: Libtask
using Turing

# We test ambiguities separately because it catches a lot of problems
# in dependencies but we test it for Turing.
Aqua.test_ambiguities([Turing])
# We test ambiguities specifically only for Turing, because testing ambiguities for all
# packages in the environment leads to a lot of ambiguities from dependencies that we cannot
# control.
#
# `Libtask.might_produce` is excluded because the `@might_produce` macro generates a lot of
# ambiguities that will never happen in practice.
#
# Specifically, when you write `@might_produce f` for a function `f` that has methods that
# take keyword arguments, we have to generate a `might_produce` method for
# `Type{<:Tuple{<:Function,...,typeof(f)}}`. There is no way to circumvent this: see
# https://github.com/TuringLang/Libtask.jl/issues/197. This in turn will cause method
# ambiguities with any other function, say `g`, for which
# `::Type{<:Tuple{typeof(g),Vararg}}` is marked as produceable.
#
# To avoid the method ambiguities, we *could* manually spell out `might_produce` methods for
# each method of `g` instead of using Vararg, but that would be both very verbose
# and fragile. It would also not provide any real benefit since those ambiguities are not
# meaningful in practice (in particular, to trigger this we would need to call `g(..., f)`,
# which is incredibly unlikely).
Aqua.test_ambiguities([Turing]; exclude=[Libtask.might_produce])
Aqua.test_all(Turing; ambiguities=false)

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -57,6 +58,7 @@ DynamicPPL = "0.39.6"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
HypothesisTests = "0.11"
Libtask = "0.9.14"
LinearAlgebra = "1"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
Expand Down
13 changes: 4 additions & 9 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,12 @@ using Turing
return priors
end

@test_throws ErrorException chain = sample(
StableRNG(seed), gauss2(; x=x), PG(10), 10
)
@test_throws ErrorException chain = sample(
StableRNG(seed), gauss2(; x=x), SMC(), 10
)

@test_throws ErrorException chain = sample(
chain = sample(StableRNG(seed), gauss2(; x=x), PG(10), 10)
chain = sample(StableRNG(seed), gauss2(; x=x), SMC(), 10)
chain = sample(
StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10
)
@test_throws ErrorException chain = sample(
chain = sample(
StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10
)

Expand Down
43 changes: 38 additions & 5 deletions test/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,59 @@ end
@test mean(c[:x]) > 0.7
end

# https://github.com/TuringLang/Turing.jl/issues/2007
@testset "keyword argument handling" begin
@model function kwarg_demo(y; n=0.0)
x ~ Normal(n)
return y ~ Normal(x)
end
@test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10)

# Check that enabling `might_produce` does allow sampling
@might_produce kwarg_demo
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000)
@test chain isa MCMCChains.Chains
@test mean(chain[:x]) ≈ 2.5 atol = 0.2

# Check that the keyword argument's value is respected
chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000)
@test chain2 isa MCMCChains.Chains
@test mean(chain2[:x]) ≈ 7.5 atol = 0.2
end

@testset "submodels without kwargs" begin
@model function inner(y, x)
# Mark as noinline explicitly to make sure that behaviour is not reliant on the
# Julia compiler inlining it.
# See https://github.com/TuringLang/Turing.jl/issues/2772
@noinline
return y ~ Normal(x)
end
@model function nested(y)
x ~ Normal()
return a ~ to_submodel(inner(y, x))
end
m1 = nested(1.0)
chn = sample(StableRNG(468), m1, PG(10), 1000)
@test mean(chn[:x]) ≈ 0.5 atol = 0.1
end

@testset "submodels with kwargs" begin
@model function inner_kwarg(y; n=0.0)
@noinline # See above
x ~ Normal(n)
return y ~ Normal(x)
end
@model function outer_kwarg1()
return a ~ to_submodel(inner_kwarg(5.0))
end
m1 = outer_kwarg1()
chn1 = sample(StableRNG(468), m1, PG(10), 1000)
@test mean(chn1[Symbol("a.x")]) ≈ 2.5 atol = 0.2

@model function outer_kwarg2(n)
return a ~ to_submodel(inner_kwarg(5.0; n=n))
end
m2 = outer_kwarg2(10.0)
chn2 = sample(StableRNG(468), m2, PG(10), 1000)
@test mean(chn2[Symbol("a.x")]) ≈ 7.5 atol = 0.2
end

@testset "refuses to run threadsafe eval" begin
# PG can't run models that have nondeterministic evaluation order,
# so it should refuse to run models marked as threadsafe.
Expand Down
Loading