diff --git a/HISTORY.md b/HISTORY.md index e22e882a..614ea964 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,9 @@ +# 0.9.15 + +Improve Libtask's handling of non-const global variables. +Previously, usage of non-const globals in a TapedTask would cause Libtask to throw an "Unbound GlobalRef not allowed in value position" error. +They should now work, and you can mutate global variables between calls to `consume`. + # 0.9.14 Added the `Libtask.might_produce_if_sig_contains` method. diff --git a/Project.toml b/Project.toml index 6ad51d8e..ba1faee2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.14" +version = "0.9.15" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 2e93628b..0c83a80a 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1027,7 +1027,17 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} push!(inst_pairs, (id, inst)) elseif stmt isa GlobalRef ref_ind = ssa_id_to_ref_index_map[id] - expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt) + # We can only use `stmt` as an argument to `set_ref_at!` if it is a + # `const` binding. If it's not const, then we need to generate a new SSA + # value for it. + set_ref_at_arg = if isconst(stmt) + stmt + else + new_id = ID() + push!(inst_pairs, (new_id, new_inst(stmt))) + new_id + end + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, set_ref_at_arg) push!(inst_pairs, (id, new_inst(expr))) elseif stmt isa Core.PiNode if stmt.val isa ID diff --git a/test/copyable_task.jl b/test/copyable_task.jl index b113f06c..fb1f32dd 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -3,6 +3,9 @@ module CopyableTaskTests using Libtask using Test +# Used later. +__global_a = 1.0 + @testset "copyable_task" begin @testset "get_taped_globals outside of a task" begin # This testset must come first because subsequent calls to get_taped_globals / @@ -476,6 +479,23 @@ using Test @test consume(t) === nothing end end + + @testset "(non-const) global variables in TapedTasks" begin + function global_f() + produce(__global_a + 1.0) + produce(__global_a + 1.0) + return nothing + end + # TapedTask construction used to error: + # https://github.com/TuringLang/Libtask.jl/issues/211 + t = TapedTask(nothing, global_f) + @test consume(t) == 2.0 + # Check that you can mutate the variable between `produce`s. + global __global_a + __global_a = 10.0 + @test consume(t) == 11.0 + @test consume(t) === nothing + end end end # module diff --git a/test/integration/turing/main.jl b/test/integration/turing/main.jl index 087d390c..20ff1671 100644 --- a/test/integration/turing/main.jl +++ b/test/integration/turing/main.jl @@ -34,12 +34,12 @@ model = f() # Check that enabling `might_produce` does allow sampling @might_produce kwarg_demo chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000; progress=false) - @test mean(chain[:x]) ≈ 2.5 atol = 0.2 + @test mean(chain[:x]) ≈ 2.5 atol = 0.3 # Check that the keyword argument's value is respected chain2 = sample( StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000; progress=false ) - @test mean(chain2[:x]) ≈ 7.5 atol = 0.2 + @test mean(chain2[:x]) ≈ 7.5 atol = 0.3 end end