Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 0.9.15

Improve Libtask's handling of non-const global variables.
Previously, usage of non-const globals in a TapedTask would cause Libtask to throw an "Unbound GlobalRef not allowed in value position" error.
They should now work, and you can mutate global variables between calls to `consume`.

# 0.9.14

Added the `Libtask.might_produce_if_sig_contains` method.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.14"
version = "0.9.15"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
12 changes: 11 additions & 1 deletion src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,17 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
push!(inst_pairs, (id, inst))
elseif stmt isa GlobalRef
ref_ind = ssa_id_to_ref_index_map[id]
expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt)
# We can only use `stmt` as an argument to `set_ref_at!` if it is a
# `const` binding. If it's not const, then we need to generate a new SSA
# value for it.
set_ref_at_arg = if isconst(stmt)
stmt
else
new_id = ID()
push!(inst_pairs, (new_id, new_inst(stmt)))
new_id
end
expr = Expr(:call, set_ref_at!, refs_id, ref_ind, set_ref_at_arg)
push!(inst_pairs, (id, new_inst(expr)))
elseif stmt isa Core.PiNode
if stmt.val isa ID
Expand Down
20 changes: 20 additions & 0 deletions test/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module CopyableTaskTests
using Libtask
using Test

# Used later.
__global_a = 1.0

@testset "copyable_task" begin
@testset "get_taped_globals outside of a task" begin
# This testset must come first because subsequent calls to get_taped_globals /
Expand Down Expand Up @@ -476,6 +479,23 @@ using Test
@test consume(t) === nothing
end
end

@testset "(non-const) global variables in TapedTasks" begin
function global_f()
produce(__global_a + 1.0)
produce(__global_a + 1.0)
return nothing
end
# TapedTask construction used to error:
# https://github.com/TuringLang/Libtask.jl/issues/211
t = TapedTask(nothing, global_f)
@test consume(t) == 2.0
# Check that you can mutate the variable between `produce`s.
global __global_a
__global_a = 10.0
@test consume(t) == 11.0
@test consume(t) === nothing
end
end

end # module
4 changes: 2 additions & 2 deletions test/integration/turing/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ model = f()
# Check that enabling `might_produce` does allow sampling
@might_produce kwarg_demo
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000; progress=false)
@test mean(chain[:x]) ≈ 2.5 atol = 0.2
@test mean(chain[:x]) ≈ 2.5 atol = 0.3

# Check that the keyword argument's value is respected
chain2 = sample(
StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000; progress=false
)
@test mean(chain2[:x]) ≈ 7.5 atol = 0.2
@test mean(chain2[:x]) ≈ 7.5 atol = 0.3
end
end
Loading