diff --git a/HISTORY.md b/HISTORY.md index 7d6af24f..e22e882a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.9.14 + +Added the `Libtask.might_produce_if_sig_contains` method. + # 0.9.13 Fix a bug where SSA registers in `throw_undef_if_not` expressions were not being correctly handled. diff --git a/Project.toml b/Project.toml index 04993379..6ad51d8e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.13" +version = "0.9.14" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/docs/src/index.md b/docs/src/index.md index ac57a0c9..e2d08620 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -31,4 +31,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme ```@docs; canonical=true Libtask.might_produce(::Type{<:Tuple}) Libtask.@might_produce +Libtask.might_produce_if_sig_contains ``` diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 521b0467..2e93628b 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -484,6 +484,26 @@ that, by default, we assume that calls do not contain `Libtask.produce` statemen """ might_produce(::Type{<:Tuple}) = false +""" + might_produce_if_sig_contains(::Type{T})::Bool + +Mark *any* method as being able to `produce` if `T` is found anywhere in its signature. + +Note that if `T` is an abstract type, you will have to use +`might_produce_if_sig_contains(::Type{<:T})` to mark methods which have subtypes of `T` in +their signature as being able to `produce`. + +For example, if `might_produce_if_sig_contains(::Type{<:AbstractFoo}) = true`, then any +method that takes an argument of `Foo <: AbstractFoo` will be treated as having +`might_produce = true`. + +!!! warning + This method should be used with caution, as it is a very broad brush. + It is only really intended for use with Turing.jl. +""" +might_produce_if_sig_contains(::Type) = false +might_produce_if_sig_contains(::typeof(Vararg)) = false + """ @might_produce(f) @@ -593,7 +613,12 @@ function stmt_might_produce(x, ret_type::Type)::Bool # Statement will terminate in the usual fashion, so _do_ bother recusing. is_produce_stmt(x) && return true - Meta.isexpr(x, :invoke) && return might_produce(get_mi(x.args[1]).specTypes) + if Meta.isexpr(x, :invoke) + mi_sig = get_mi(x.args[1]).specTypes + return ( + might_produce(mi_sig) || any(might_produce_if_sig_contains, mi_sig.parameters) + ) + end if Meta.isexpr(x, :call) # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. f = get_function(x.args[1]) diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 185aabc3..b113f06c 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -403,6 +403,79 @@ using Test t = TapedTask(nothing, tuin_f, 2) @test_throws UndefVarError consume(t) end + + @testset "might_produce_if_sig_contains" begin + @testset "concrete type" begin + struct MyType end + @noinline function mp_g(x::Int, ::MyType) + produce(x) + return nothing + end + @noinline function mp_g(::MyType; x::Int=3) + produce(x) + return nothing + end + @noinline function mp_g(x::Int) + produce(x) + return nothing + end + + # Use function reference to ensure that mp_g doesn't get inlined. + function mp_f(x, gref) + gref[](x, MyType()) # Should produce + gref[](MyType(); x=4) # Should produce + gref[](x) # Should not produce + return nothing + end + # Before marking it as might_produce, nothing should produce. + t = TapedTask(nothing, mp_f, 1, Ref(mp_g)) + @test consume(t) === nothing + + # Now marking `MyType` as causing produces should make the first two calls + # produce, but not the third. + Libtask.might_produce_if_sig_contains(::Type{MyType}) = true + t = TapedTask(nothing, mp_f, 1, Ref(mp_g)) + @test consume(t) == 1 + @test consume(t) == 4 + @test consume(t) === nothing + end + + @testset "abstract type" begin + abstract type AbstractTT end + struct MyType2 <: AbstractTT end + @noinline function abs_mp_g(x::Int, ::MyType2) + produce(x) + return nothing + end + @noinline function abs_mp_g(::MyType2; x::Int=3) + produce(x) + return nothing + end + @noinline function abs_mp_g(x::Int) + produce(x) + return nothing + end + + # Use function reference to ensure that mp_g doesn't get inlined. + function abs_mp_f(x, gref) + gref[](x, MyType2()) # Should produce + gref[](MyType2(); x=4) # Should produce + gref[](x) # Should not produce + return nothing + end + # Before marking it as might_produce, nothing should produce. + t = TapedTask(nothing, abs_mp_f, 1, Ref(abs_mp_g)) + @test consume(t) === nothing + + # Now marking `AbstractTT` as causing produces should make the first two calls + # produce, but not the third. + Libtask.might_produce_if_sig_contains(::Type{<:AbstractTT}) = true + t = TapedTask(nothing, abs_mp_f, 1, Ref(abs_mp_g)) + @test consume(t) == 1 + @test consume(t) == 4 + @test consume(t) === nothing + end + end end end # module