Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.9.14

Added the `Libtask.might_produce_if_sig_contains` method.

# 0.9.13

Fix a bug where SSA registers in `throw_undef_if_not` expressions were not being correctly handled.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.13"
version = "0.9.14"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme
```@docs; canonical=true
Libtask.might_produce(::Type{<:Tuple})
Libtask.@might_produce
Libtask.might_produce_if_sig_contains
```
27 changes: 26 additions & 1 deletion src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,26 @@ that, by default, we assume that calls do not contain `Libtask.produce` statemen
"""
might_produce(::Type{<:Tuple}) = false

"""
might_produce_if_sig_contains(::Type{T})::Bool

Mark *any* method as being able to `produce` if `T` is found anywhere in its signature.

Note that if `T` is an abstract type, you will have to use
`might_produce_if_sig_contains(::Type{<:T})` to mark methods which have subtypes of `T` in
their signature as being able to `produce`.

For example, if `might_produce_if_sig_contains(::Type{<:AbstractFoo}) = true`, then any
method that takes an argument of `Foo <: AbstractFoo` will be treated as having
`might_produce = true`.

!!! warning
This method should be used with caution, as it is a very broad brush.
It is only really intended for use with Turing.jl.
"""
might_produce_if_sig_contains(::Type) = false
might_produce_if_sig_contains(::typeof(Vararg)) = false

"""
@might_produce(f)

Expand Down Expand Up @@ -593,7 +613,12 @@ function stmt_might_produce(x, ret_type::Type)::Bool

# Statement will terminate in the usual fashion, so _do_ bother recusing.
is_produce_stmt(x) && return true
Meta.isexpr(x, :invoke) && return might_produce(get_mi(x.args[1]).specTypes)
if Meta.isexpr(x, :invoke)
mi_sig = get_mi(x.args[1]).specTypes
return (
might_produce(mi_sig) || any(might_produce_if_sig_contains, mi_sig.parameters)
)
end
if Meta.isexpr(x, :call)
# This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
f = get_function(x.args[1])
Expand Down
73 changes: 73 additions & 0 deletions test/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,79 @@ using Test
t = TapedTask(nothing, tuin_f, 2)
@test_throws UndefVarError consume(t)
end

@testset "might_produce_if_sig_contains" begin
@testset "concrete type" begin
struct MyType end
@noinline function mp_g(x::Int, ::MyType)
produce(x)
return nothing
end
@noinline function mp_g(::MyType; x::Int=3)
produce(x)
return nothing
end
@noinline function mp_g(x::Int)
produce(x)
return nothing
end

# Use function reference to ensure that mp_g doesn't get inlined.
function mp_f(x, gref)
gref[](x, MyType()) # Should produce
gref[](MyType(); x=4) # Should produce
gref[](x) # Should not produce
return nothing
end
# Before marking it as might_produce, nothing should produce.
t = TapedTask(nothing, mp_f, 1, Ref(mp_g))
@test consume(t) === nothing

# Now marking `MyType` as causing produces should make the first two calls
# produce, but not the third.
Libtask.might_produce_if_sig_contains(::Type{MyType}) = true
t = TapedTask(nothing, mp_f, 1, Ref(mp_g))
@test consume(t) == 1
@test consume(t) == 4
@test consume(t) === nothing
end

@testset "abstract type" begin
abstract type AbstractTT end
struct MyType2 <: AbstractTT end
@noinline function abs_mp_g(x::Int, ::MyType2)
produce(x)
return nothing
end
@noinline function abs_mp_g(::MyType2; x::Int=3)
produce(x)
return nothing
end
@noinline function abs_mp_g(x::Int)
produce(x)
return nothing
end

# Use function reference to ensure that mp_g doesn't get inlined.
function abs_mp_f(x, gref)
gref[](x, MyType2()) # Should produce
gref[](MyType2(); x=4) # Should produce
gref[](x) # Should not produce
return nothing
end
# Before marking it as might_produce, nothing should produce.
t = TapedTask(nothing, abs_mp_f, 1, Ref(abs_mp_g))
@test consume(t) === nothing

# Now marking `AbstractTT` as causing produces should make the first two calls
# produce, but not the third.
Libtask.might_produce_if_sig_contains(::Type{<:AbstractTT}) = true
t = TapedTask(nothing, abs_mp_f, 1, Ref(abs_mp_g))
@test consume(t) == 1
@test consume(t) == 4
@test consume(t) === nothing
end
end
end

end # module