From 0a7080a7c761bbb6643ecc0e71b65031d6559b2e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 24 Mar 2026 14:36:48 +0000 Subject: [PATCH 01/10] Break code up --- src/Libtask.jl | 1 + src/copyable_task.jl | 850 ------------------------------------------ src/transformation.jl | 848 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 849 insertions(+), 850 deletions(-) create mode 100644 src/transformation.jl diff --git a/src/Libtask.jl b/src/Libtask.jl index 473c6068..2d6475bd 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -14,6 +14,7 @@ include("bbcode.jl") using .BasicBlockCode include("copyable_task.jl") +include("transformation.jl") include("test_utils.jl") export TapedTask, diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 0c83a80a..569bda08 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -569,853 +569,3 @@ macro might_produce(f) end end end - -# Helper struct used in `derive_copyable_task_ir`. -struct TupleRef - n::Int -end - -# Unclear whether this is needed. -get_value(x::GlobalRef) = getglobal(x.mod, x.name) -get_value(x::QuoteNode) = x.value -get_value(x) = x - -""" - is_produce_stmt(x)::Bool - -`true` if `x` is an expression of the form `Expr(:call, produce, %x)` or a similar `:invoke` -expression, otherwise `false`. -""" -function is_produce_stmt(x)::Bool - if Meta.isexpr(x, :invoke) && - length(x.args) == 3 && - x.args[1] isa Union{Core.MethodInstance,Core.CodeInstance} - return get_mi(x.args[1]).specTypes <: Tuple{typeof(produce),Any} - elseif Meta.isexpr(x, :call) && length(x.args) == 2 - return get_value(x.args[1]) === produce - else - return false - end -end - -""" - stmt_might_produce(x, ret_type::Type)::Bool - -`true` if `x` might contain a call to `produce`, and `false` otherwise. -""" -function stmt_might_produce(x, ret_type::Type)::Bool - - # Statement will terminate in an unusual fashion, so don't bother recursing. 
- # This isn't _strictly_ correct (there could be a `produce` statement before the - # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the - # user should get a sensible error message anyway. - ret_type == Union{} && return false - - # Statement will terminate in the usual fashion, so _do_ bother recusing. - is_produce_stmt(x) && return true - if Meta.isexpr(x, :invoke) - mi_sig = get_mi(x.args[1]).specTypes - return ( - might_produce(mi_sig) || any(might_produce_if_sig_contains, mi_sig.parameters) - ) - end - if Meta.isexpr(x, :call) - # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. - f = get_function(x.args[1]) - _might_produce = !isa(f, Union{Core.IntrinsicFunction,Core.Builtin,DataType}) - return _might_produce - end - return false -end - -get_function(x) = x -get_function(x::Expr) = eval(x) -get_function(x::GlobalRef) = isconst(x) ? getglobal(x.mod, x.name) : x.binding - -""" - produce_value(x::Expr) - -Returns the value that a `produce` statement returns. For example, for the statment -`produce(%x)`, this function will return `%x`. -""" -function produce_value(x::Expr) - is_produce_stmt(x) || throw(error("Not a produce statement. Please report this error.")) - Meta.isexpr(x, :invoke) && return x.args[3] - return x.args[2] # must be a `:call` Expr. -end - -struct ProducedValue{T} - x::T -end -ProducedValue(::Type{T}) where {T} = ProducedValue{Type{T}}(T) - -@inline Base.getindex(x::ProducedValue) = x.x - -""" - inc_args(stmt::T)::T where {T} - -Returns a new `T` which is equal to `stmt`, except any `Argument`s present in `stmt` are -incremented by `1`. For example -```jldoctest -julia> Libtask.inc_args(Core.ReturnNode(Core.Argument(1))) -:(return _2) -``` -""" -inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) -inc_args(x::ReturnNode) = isdefined(x, :val) ? 
ReturnNode(__inc(x.val)) : x -inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) -inc_args(x::IDGotoNode) = x -function inc_args(x::IDPhiNode) - new_values = Vector{Any}(undef, length(x.values)) - for n in eachindex(x.values) - if isassigned(x.values, n) - new_values[n] = __inc(x.values[n]) - end - end - return IDPhiNode(x.edges, new_values) -end -inc_args(::Nothing) = nothing -inc_args(x::GlobalRef) = x -inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ)) - -__inc(x::Argument) = Argument(x.n + 1) -__inc(x) = x - -const TypeInfo = Tuple{Vector{Any},Dict{ID,Type}} - -""" - _typeof(x) - -Central definition of typeof, which is specific to the use-required in this package. -Largely the same as `Base._stable_typeof`, differing only in a handful of -situations, for example: -```jldoctest -julia> Base._stable_typeof((Float64,)) -Tuple{DataType} - -julia> Libtask._typeof((Float64,)) -Tuple{Type{Float64}} -``` -""" -_typeof(x) = Base._stable_typeof(x) -_typeof(x::Tuple) = Tuple{map(_typeof, x)...} -_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))} - -""" - get_type(info::ADInfo, x) - -Returns the static / inferred type associated to `x`. -""" -get_type(info::TypeInfo, x::Argument) = info[1][x.n - 1] -get_type(info::TypeInfo, x::ID) = CC.widenconst(info[2][x]) -get_type(::TypeInfo, x::QuoteNode) = _typeof(x.value) -get_type(::TypeInfo, x) = _typeof(x) -function get_type(::TypeInfo, x::GlobalRef) - return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty -end -function get_type(::TypeInfo, x::Expr) - x.head === :boundscheck && return Bool - return error("Unrecognised expression $x found in argument slot.") -end - -function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} - - # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s - # to implement `TapedTask`s, this appears via the first argument. - refs_id = Argument(1) - - # Increment all arguments by 1. 
- for bb in ir.blocks, (n, inst) in enumerate(bb.insts) - bb.insts[n] = CC.NewInstruction( - inc_args(inst.stmt), inst.type, inst.info, inst.line, inst.flag - ) - end - - # Construct map between SSA IDs and their index in the state data structure and back. - # Also construct a map from each ref index to its type. We only construct `Ref`s - # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful - # return value, so there's no need to allocate a `Ref` for them. - ssa_id_to_ref_index_map = Dict{ID,Int}() - ref_index_to_ssa_id_map = Dict{Int,ID}() - ref_index_to_type_map = Dict{Int,Type}() - id_to_type_map = Dict{ID,Type}() - is_used_dict = characterise_used_ids(collect_stmts(ir)) - n = 0 - for bb in ir.blocks - for (id, stmt) in zip(bb.inst_ids, bb.insts) - id_to_type_map[id] = CC.widenconst(stmt.type) - stmt.stmt isa IDGotoNode && continue - stmt.stmt isa IDGotoIfNot && continue - stmt.stmt === nothing && continue - stmt.stmt isa ReturnNode && continue - is_used_dict[id] || continue - n += 1 - ssa_id_to_ref_index_map[id] = n - ref_index_to_ssa_id_map[n] = id - ref_index_to_type_map[n] = CC.widenconst(stmt.type) - end - end - - # Specify data structure containing `Ref`s for all of the SSAs. - _refs = Any[Ref{ref_index_to_type_map[p]}() for p in 1:length(ref_index_to_ssa_id_map)] - - # Ensure that each basic block ends with a non-producing statement. This is achieved by - # replacing any fall-through terminators with `IDGotoNode`s. This is not strictly - # necessary, but simplifies later stages of the pipeline, as discussed variously below. - for (n, block) in enumerate(ir.blocks) - if terminator(block) === nothing - # Fall-through terminator, so next block in `ir.blocks` is the unique successor - # block of `block`. Final block cannot have a fall-through terminator, so asking - # for element `n + 1` is always going to be valid. 
- successor_id = ir.blocks[n + 1].id - push!(block.insts, new_inst(IDGotoNode(successor_id))) - push!(block.inst_ids, ID()) - end - end - - # For each existing basic block, create a sequence of `NamedTuple`s which - # define the manner in which it must be split. - # A block will in general be split as follows: - # 1 - %1 = φ(...) - # 1 - %2 = φ(...) - # 1 - %3 = call_which_must_not_produce(...) - # 1 - %4 = produce(%3) - # 2 - %5 = call_which_must_not_produce(...) - # 2 - %6 = call_which_might_produce(...) - # 3 - %7 = call_which_must_not_produce(...) - # 3 - terminator (GotoIfNot, GotoNode, etc) - # - # The numbers on the left indicate which split each statement falls. The first - # split comprises all statements up until the first produce / call-which-might-produce. - # Consequently, the first split will always contain any `PhiNode`s present in the block. - # The next set of statements up until the next produce / call-which-might-produce form - # the second split, and so on. - # We enforced above the condition that the final statement in a basic block must not - # produce. This ensures that the final split does not produce. While not strictly - # necessary, this simplifies the implementation (see below). - # - # As a result of the above, a basic block will be associated to exactly one split if it - # does not contain any statements which may produce. - # - # Each `NamedTuple` contains a `start` index and `last` index, indicating the position - # in the block at which the corresponding split starts and finishes. - all_splits = map(ir.blocks) do block - split_ends = vcat( - findall( - inst -> stmt_might_produce(inst.stmt, CC.widenconst(inst.type)), - block.insts, - ), - length(block), - ) - return map(enumerate(split_ends)) do (n, split_end) - return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end) - end - end - - # Owing to splitting blocks up, we will need to re-label some `GotoNode`s and - # `GotoIfNot`s. 
To understand this, consider the following block, whose original `ID` - # we assume to be `ID(old_id)`. - # ID(new_id) - %1 = φ(ID(3) => ...) - # ID(new_id) - %3 = call_which_must_not_produce(...) - # ID(new_id) - %4 = produce(%3) - # ID(old_id) - GotoNode(ID(5)) - # - # In the above, the entire block was original associated to a single ID, `ID(old_id)`, - # but is now split into two sections. We keep the original ID for the final split, and - # assign a new one to the first split. As a result, any `PhiNode`s in other blocks - # which have edges incoming from `ID(old_id)` will remain valid. - # However, if we adopt this strategy for all blocks, `ID(5)` in the `GotoNode` at the - # end of the block will refer to the wrong block if the block original associated to - # `ID(5)` was itself split, since the "top" of that block will have a new `ID`. - # - # To resolve this, we: - # 1. Associate an ID to each split in each block, ensuring that the ID for the final - # split of each block is the same ID as that of the original block. - all_split_ids = map(zip(ir.blocks, all_splits)) do (block, splits) - return vcat([ID() for _ in splits[1:(end - 1)]], block.id) - end - - # 2. Construct a map between the ID of each block and the ID associated to its split. - top_split_id_map = Dict{ID,ID}(b.id => x[1] for (b, x) in zip(ir.blocks, all_split_ids)) - - # 3. Update all `GotoNode`s and `GotoIfNot`s to refer to these new names. - for block in ir.blocks - t = terminator(block) - if t isa IDGotoNode - block.insts[end] = new_inst(IDGotoNode(top_split_id_map[t.label])) - elseif t isa IDGotoIfNot - block.insts[end] = new_inst(IDGotoIfNot(t.cond, top_split_id_map[t.dest])) - end - end - - # A set of blocks from which we might wish to resume computation. - resume_block_ids = Vector{ID}() - - # A list onto which we'll push the type of any statement which might produce. - possible_produce_types = Any[] - - # This where most of the action happens. 
- # - # For each split of each block, we must - # 1. translate all statements which accept any SSAs as arguments, or return a value, - # into statements which read in data from the `Ref`s containing the value associated - # to each SSA, and write the result to `Ref`s associated to the SSA of the line in - # question. - # 2. add additional code at the end of the split to handle the possibility that the - # last statement produces (per the definition of the splits above). This applies to - # all splits except the last, which cannot produce by construction. Exactly what - # happens here depends on whether the last statement is a `produce` call, or a - # call-which-might-produce -- see below for specifics. - # - # This code transforms each block (and its splits) into a new collection of blocks. - # Note that the total number of new blocks may be greater than the total number of - # splits, because each split ending in a call-which-might-produce requires more than a - # single block to implement the required resumption functionality. - new_bblocks = map(zip(ir.blocks, all_splits, all_split_ids)) do (bb, splits, splits_ids) - new_blocks = map(enumerate(splits)) do (n, split) - # We'll push ID-NewInstruction pairs to this as we proceed through the split. - inst_pairs = IDInstPair[] - - # PhiNodes: - # - # A single `PhiNode` - # - # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%n)) - # - # sets `ID(%1)` to either `1` or whatever value is currently associated to - # `ID(%n)`, depending upon whether the predecessor block was `ID(#1)` or - # `ID(#2)`. 
Consequently, a single `PhiNode` can be transformed into something - # along the lines of: - # - # ID(%1) = φ(ID(#1) => 1, ID(#2) => TupleRef(ref_ind_for_ID(%n))) - # ID(%2) = deref_phi(refs, ID(%1)) - # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%2)) - # - # where `deref_phi` retrieves the value in position `ref_ind_for_ID(%n)` if - # ID(%1) is a `TupleRef`, and `1` otherwise, and `set_ref_at!` sets the `Ref` - # at position `ref_ind_for_ID(%1)` to the value of `ID(%2)`. See the actual - # implementations below. - # - # If we have multiple `PhiNode`s at the start of a block, we must run all of - # them, then dereference all of their values, and finally write all of the - # de-referenced values to the appropriate locations. This is because - # a. we require all `PhiNode`s appear together at the top of a given basic - # block, and - # b. the semantics of `PhiNode`s is that they are all "run" simultaneously. This - # only matters if one `PhiNode` in the block can refer to the value stored in - # the SSA associated to another. For example, something along the lines of: - # - # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%2)) - # ID(%2) = φ(ID(#1) => 1, ID(#2) => 2) - # - # (we leave it as an exercise for the reader to figure out why this particular - # semantic feature of `PhiNode`s is relevant in this specific case). - # - # So, in general, the code produced by this block will look roughly like - # - # ID(%1) = φ(...) - # ID(%2) = φ(...) - # ID(%3) = φ(...) - # ID(%4) = deref_phi(refs, ID(%1)) - # ID(%5) = deref_phi(refs, ID(%2)) - # ID(%6) = deref_phi(refs, ID(%3)) - # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%4)) - # set_ref_at!(refs, ref_ind_for_ID(%2), ID(%5)) - # set_ref_at!(refs, ref_ind_for_ID(%3), ID(%6)) - if n == 1 - # Find all PhiNodes in the block -- will definitely be in this split. - phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts) - - # Replace SSA IDs with `TupleRef`s, and record these instructions. 
- phi_ids = map(phi_inds) do n - phi = bb.insts[n].stmt - for i in eachindex(phi.values) - isassigned(phi.values, i) || continue - v = phi.values[i] - v isa ID || continue - phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v]) - end - phi_id = ID() - push!(inst_pairs, (phi_id, new_inst(phi, Any))) - return phi_id - end - - # De-reference values associated to `IDPhiNode`s. - deref_ids = map(phi_inds) do n - id = bb.inst_ids[n] - phi_id = phi_ids[n] - ref_ind = ssa_id_to_ref_index_map[id] - push!( - inst_pairs, - # The last argument, ref_index_to_type_map[ref_ind], is a - # performance optimisation. The idea is that we know the inferred - # type of the PhiNode from the original IR, and by passing it to - # deref_phi we can type annotate the element type of the Ref - # that it's being dereferenced, resulting in more concrete types - # in the generated IR. - ( - id, - new_inst( - Expr( - :call, - deref_phi, - refs_id, - phi_id, - ref_index_to_type_map[ref_ind], - ), - ), - ), - ) - return id - end - - # Update values stored in `Ref`s associated to `PhiNode`s. - for n in phi_inds - ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] - expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) - push!(inst_pairs, (ID(), new_inst(expr))) - end - end - - # Statements which do not produce: - # - # Iterate every statement in the split other than the final one, replacing uses - # of SSAs with de-referenced `Ref`s, and writing the results of statements to - # the corresponding `Ref`s. - _ids = view(bb.inst_ids, (split.start):(split.last - 1)) - _insts = view(bb.insts, (split.start):(split.last - 1)) - for (id, inst) in zip(_ids, _insts) - stmt = inst.stmt - if Meta.isexpr(stmt, :invoke) || - Meta.isexpr(stmt, :call) || - Meta.isexpr(stmt, :new) || - Meta.isexpr(stmt, :foreigncall) || - Meta.isexpr(stmt, :throw_undef_if_not) - - # Find any `ID`s and replace them with calls to read whatever is stored - # in the `Ref`s that they are associated to. 
- for (n, arg) in enumerate(stmt.args) - arg isa ID || continue - - new_id = ID() - ref_ind = ssa_id_to_ref_index_map[arg] - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (new_id, new_inst(expr))) - stmt.args[n] = new_id - end - - # Push the target instruction to the list. - push!(inst_pairs, (id, inst)) - - # If we know it is not possible for this statement to contain any calls - # to produce, then simply write out the result to its `Ref`. If it is - # never used, then there is no need to store it. - if is_used_dict[id] - out_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) - push!(inst_pairs, (ID(), new_inst(set_ref))) - end - elseif Meta.isexpr(stmt, :boundscheck) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :code_coverage_effect) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :gc_preserve_begin) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :gc_preserve_end) - push!(inst_pairs, (id, inst)) - elseif stmt isa Nothing - push!(inst_pairs, (id, inst)) - elseif stmt isa GlobalRef - ref_ind = ssa_id_to_ref_index_map[id] - # We can only use `stmt` as an argument to `set_ref_at!` if it is a - # `const` binding. If it's not const, then we need to generate a new SSA - # value for it. 
- set_ref_at_arg = if isconst(stmt) - stmt - else - new_id = ID() - push!(inst_pairs, (new_id, new_inst(stmt))) - new_id - end - expr = Expr(:call, set_ref_at!, refs_id, ref_ind, set_ref_at_arg) - push!(inst_pairs, (id, new_inst(expr))) - elseif stmt isa Core.PiNode - if stmt.val isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.val] - val_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (val_id, new_inst(expr))) - push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) - else - push!(inst_pairs, (id, inst)) - end - set_ind = ssa_id_to_ref_index_map[id] - set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) - push!(inst_pairs, (ID(), new_inst(set_expr))) - elseif stmt isa IDPhiNode - # do nothing -- we've already handled any `PhiNode`s. - elseif Meta.isexpr(stmt, :loopinfo) - push!(inst_pairs, (id, inst)) - else - throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) - end - end - - # TODO: explain this better. - new_blocks = BBlock[] - - # Produce and Terminators: - # - # Handle the last statement in the split. - id = bb.inst_ids[split.last] - inst = bb.insts[split.last] - stmt = inst.stmt - if n == length(splits) - # This is the last split in the block, so it must end with a non-producing - # terminator. We handle this in a similar way to the statements above. - - if stmt isa ReturnNode - # Reset the position counter to `-1`, so that if this function gets - # called again, execution starts from the beginning. - expr = Expr(:call, set_resume_block!, refs_id, Int32(-1)) - push!(inst_pairs, (ID(), new_inst(expr))) - # If returning an SSA, it might be one whose value was restored from - # before. Therefore, grab it out of storage, rather than assuming that - # it is def-ed. 
- if isdefined(stmt, :val) && stmt.val isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.val] - val_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (val_id, new_inst(expr))) - push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) - else - push!(inst_pairs, (id, inst)) - end - elseif stmt isa IDGotoIfNot - # If the condition is an SSA, it might be one whose value was restored - # from before. Therefore, grab it out of storage, rather than assuming - # that it is defined. - if stmt.cond isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.cond] - cond_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (cond_id, new_inst(expr))) - push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) - else - push!(inst_pairs, (id, inst)) - end - elseif stmt isa IDGotoNode - push!(inst_pairs, (id, inst)) - else - error("Unexpected terminator $stmt") - end - push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) - elseif is_produce_stmt(stmt) - # This is a statement of the form - # %n = produce(arg) - # - # We transform this into - # Libtask.set_resume_block!(refs_id, id_of_next_block) - # return ProducedValue(arg) - # - # The point is to ensure that, next time that this `TapedTask` is called, - # computation is resumed from the statement _after_ this produce statement, - # and to return whatever this produce statement returns. - - # Log the result type of this statement. - arg = stmt.args[Meta.isexpr(stmt, :invoke) ? 3 : 2] - push!(possible_produce_types, get_type((ir.argtypes, id_to_type_map), arg)) - - # When this TapedTask is next called, we should resume from the first - # statement of the next split. - resume_id = splits_ids[n + 1] - push!(resume_block_ids, resume_id) - - # Insert statement to enforce correct resumption behaviour. 
- resume_stmt = Expr(:call, set_resume_block!, refs_id, resume_id.id) - push!(inst_pairs, (ID(), new_inst(resume_stmt))) - - # Insert statement to construct a `ProducedValue` from the value. - # Could be that the produce references an SSA, in which case we need to - # de-reference, rather than just return the thing. - prod_val = produce_value(stmt) - if prod_val isa ID - deref_id = ID() - ref_ind = ssa_id_to_ref_index_map[prod_val] - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (deref_id, new_inst(expr))) - prod_val = deref_id - end - - # Set the ref for this statement, as we would for any other call or invoke. - # The TapedTask may need to read this ref when it resumes, if the return - # value of `produce` is used within the original function. - if is_used_dict[id] - out_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val) - push!(inst_pairs, (ID(), new_inst(set_ref))) - end - - # Construct a `ProducedValue`. - val_id = ID() - push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) - - # Insert statement to return the `ProducedValue`. - push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) - - # Construct a single new basic block from all of the inst-pairs. - push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) - else - # The final statement is one which might produce, but is not itself a - # `produce` statement. For example - # y = f(x) - # - # becomes (morally speaking) - # y = f(x) - # if y isa ProducedValue - # set_resume_block!(refs_id, id_of_current_block) - # return y - # end - # - # The point is to ensure that, if `f` "produces" (as indicated by `y` being - # a `ProducedValue`) then the next time that this TapedTask is called, we - # must resume from the call to `f`, as subsequent runs might also produce. 
- # On the other hand, if anything other than a `ProducedValue` is returned, - # we know that `f` has nothing else to produce, and execution can safely - # continue to the next split. - # In addition to the above, we must do the usual thing and ensure that any - # ssas are read from storage, and write the result of this computation to - # storage before continuing to the next instruction. - # - # You should look at the IR generated by a simple example in the test suite - # which involves calls that might produce, in order to get a sense of what - # the resulting code looks like prior to digging into the code below. - - # At present, we're not able to properly infer the values which might - # potentially be produced by a call-which-might-produce. Consequently, we - # have to assume they can produce anything. - # - # This `Any` only affects the return type of the function being derived - # here. Importantly, it does not affect the type stability of subsequent - # statements in this function. As a result, the impact ought to be - # reasonably limited. - push!(possible_produce_types, Any) - - # Create a new basic block from the existing statements, since all new - # statement need to live in their own basic blocks. - callable_block_id = ID() - push!(inst_pairs, (ID(), new_inst(IDGotoNode(callable_block_id)))) - push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) - - # Derive TapedTask for this statement. - (callable, callable_args) = if Meta.isexpr(stmt, :invoke) - sig = get_mi(stmt.args[1]).specTypes - v = Any[Any] - (LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end]) - elseif Meta.isexpr(stmt, :call) - (DynamicCallable(), stmt.args) - else - display(stmt) - println() - error("unhandled statement which might produce $stmt") - end - - # Find any `ID`s and replace them with calls to read whatever is stored - # in the `Ref`s that they are associated to. 
- callable_inst_pairs = IDInstPair[] - for (n, arg) in enumerate(callable_args) - arg isa ID || continue - - new_id = ID() - ref_ind = ssa_id_to_ref_index_map[arg] - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(callable_inst_pairs, (new_id, new_inst(expr))) - callable_args[n] = new_id - end - - # Allocate a slot in the _refs vector for this callable. - push!(_refs, Ref(callable)) - callable_ind = length(_refs) - - # Retrieve the callable from the refs. - callable_id = ID() - callable_stmt = Expr(:call, get_ref_at, refs_id, callable_ind) - push!(callable_inst_pairs, (callable_id, new_inst(callable_stmt))) - - # Call the callable. - result_id = ID() - result_stmt = Expr(:call, callable_id, callable_args...) - push!(callable_inst_pairs, (result_id, new_inst(result_stmt))) - - # Determine whether this TapedTask has produced a not-a-`ProducedValue`. - not_produced_id = ID() - not_produced_stmt = Expr(:call, not_a_produced, result_id) - push!(callable_inst_pairs, (not_produced_id, new_inst(not_produced_stmt))) - - # Go to a block which just returns the `ProducedValue`, if a - # `ProducedValue` is returned, otherwise continue to the next split. - is_produced_block_id = ID() - is_not_produced_block_id = ID() - switch = Switch( - Any[not_produced_id], - [is_produced_block_id], - is_not_produced_block_id, - ) - push!(callable_inst_pairs, (ID(), new_inst(switch))) - - # Push the above statements onto a new block. - push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs)) - - # Construct block which handles the case that we got a `ProducedValue`. If - # this happens, it means that `callable` has more things to produce still. - # This means that we need to call it again next time we enter this function. - # To achieve this, we set the resume block to the `callable_block_id`, - # and return the `ProducedValue` currently located in `result_id`. 
- push!(resume_block_ids, callable_block_id) - set_res = Expr(:call, set_resume_block!, refs_id, callable_block_id.id) - return_id = ID() - produced_block_inst_pairs = IDInstPair[ - (ID(), new_inst(set_res)), - (return_id, new_inst(ReturnNode(result_id))), - ] - push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs)) - - # Construct block which handles the case that we did not get a - # `ProducedValue`. In this case, we must first push the result to the `Ref` - # associated to the call, and goto the next split. - next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator - if is_used_dict[id] - result_ref_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id) - else - set_ref = nothing - end - not_produced_block_inst_pairs = IDInstPair[ - (ID(), new_inst(set_ref)) - (ID(), new_inst(IDGotoNode(next_block_id))) - ] - push!( - new_blocks, - BBlock(is_not_produced_block_id, not_produced_block_inst_pairs), - ) - end - return new_blocks - end - return reduce(vcat, new_blocks) - end - new_bblocks = reduce(vcat, new_bblocks) - - # Insert statements at the top. - cases = map(resume_block_ids) do id - return ID(), id, Expr(:call, resume_block_is, refs_id, id.id) - end - cond_ids = ID[x[1] for x in cases] - cond_dests = ID[x[2] for x in cases] - cond_stmts = Any[x[3] for x in cases] - switch_stmt = Switch(Any[x for x in cond_ids], cond_dests, first(new_bblocks).id) - entry_stmts = vcat(cond_stmts, nothing, switch_stmt) - entry_block = BBlock(ID(), vcat(cond_ids, ID(), ID()), map(new_inst, entry_stmts)) - new_bblocks = vcat(entry_block, new_bblocks) - - # New argtypes are the same as the old ones, except we have `Ref`s in the first argument - # rather than nothing at all. - new_argtypes = copy(ir.argtypes) - refs = (_refs..., Ref{Int32}(-1)) - new_argtypes = vcat(typeof(refs), copy(ir.argtypes)) - - # Return BBCode and the `Ref`s. 
- @static if VERSION >= v"1.12-" - new_ir = BBCode( - new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds - ) - else - new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) - end - return new_ir, refs, possible_produce_types -end - -# Helper used in `derive_copyable_task_ir`. -@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][] - -# Helper used in `derive_copyable_task_ir`. -@inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple} - refs[n][] = val - return nothing -end - -# Helper used in `derive_copyable_task_ir`. -@inline function set_resume_block!(refs::R, id::Int32) where {R<:Tuple} - refs[end][] = id - return nothing -end - -# Helper used in `derive_copyable_task_ir`. -@inline resume_block_is(refs::R, id::Int32) where {R<:Tuple} = !(refs[end][] === id) - -# Helper used in `derive_copyable_task_ir`. -@inline function deref_phi(refs::R, n::TupleRef, ::Type{T}) where {R<:Tuple,T} - ref = refs[n.n] - return ref[]::T -end -@inline deref_phi(::R, x, t::Type) where {R<:Tuple} = x - -# Helper used in `derived_copyable_task_ir`. -@inline not_a_produced(x) = !(isa(x, ProducedValue)) - -# Implement iterator interface. -function Base.iterate(t::TapedTask, state::Nothing=nothing) - v = consume(t) - return v === nothing ? nothing : (v, nothing) -end -Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() - -""" - LazyCallable - -Used to implement static dispatch, while avoiding the need to construct the -callable immediately. When constructed, just stores the signature of the -callable and its return type. Constructs the callable when first called. - -All type information is known, so it is possible to make this callable type -stable provided that the return type is concrete. 
-""" -mutable struct LazyCallable{sig<:Tuple,Tret} - mc::MistyClosure - position::Base.RefValue{Int32} - LazyCallable{sig,Tret}() where {sig,Tret} = new{sig,Tret}() -end - -function (l::LazyCallable)(args::Vararg{Any,N}) where {N} - isdefined(l, :mc) || construct_callable!(l) - return l.mc(args...) -end - -function construct_callable!(l::LazyCallable{sig}) where {sig} - mc, pos = build_callable(sig) - l.mc = mc - l.position = pos - return nothing -end - -""" - DynamicCallable - -Like [`LazyCallable`](@ref), but without any type information. Used to implement -dynamic dispatch. -""" -mutable struct DynamicCallable{V} - cache::V -end - -DynamicCallable() = DynamicCallable(Dict{Any,Any}()) - -function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} - sig = _typeof(args) - callable = get(dynamic_callable.cache, sig, nothing) - if callable === nothing - callable = build_callable(sig) - dynamic_callable.cache[sig] = callable - end - return callable[1](args...) -end diff --git a/src/transformation.jl b/src/transformation.jl new file mode 100644 index 00000000..df709b7f --- /dev/null +++ b/src/transformation.jl @@ -0,0 +1,848 @@ +# Helper struct used in `derive_copyable_task_ir`. +struct TupleRef + n::Int +end + +# Unclear whether this is needed. +get_value(x::GlobalRef) = getglobal(x.mod, x.name) +get_value(x::QuoteNode) = x.value +get_value(x) = x + +""" + is_produce_stmt(x)::Bool + +`true` if `x` is an expression of the form `Expr(:call, produce, %x)` or a similar `:invoke` +expression, otherwise `false`. 
+""" +function is_produce_stmt(x)::Bool + if Meta.isexpr(x, :invoke) && + length(x.args) == 3 && + x.args[1] isa Union{Core.MethodInstance,Core.CodeInstance} + return get_mi(x.args[1]).specTypes <: Tuple{typeof(produce),Any} + elseif Meta.isexpr(x, :call) && length(x.args) == 2 + return get_value(x.args[1]) === produce + else + return false + end +end + +""" + stmt_might_produce(x, ret_type::Type)::Bool + +`true` if `x` might contain a call to `produce`, and `false` otherwise. +""" +function stmt_might_produce(x, ret_type::Type)::Bool + + # Statement will terminate in an unusual fashion, so don't bother recursing. + # This isn't _strictly_ correct (there could be a `produce` statement before the + # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the + # user should get a sensible error message anyway. + ret_type == Union{} && return false + + # Statement will terminate in the usual fashion, so _do_ bother recusing. + is_produce_stmt(x) && return true + if Meta.isexpr(x, :invoke) + mi_sig = get_mi(x.args[1]).specTypes + return ( + might_produce(mi_sig) || any(might_produce_if_sig_contains, mi_sig.parameters) + ) + end + if Meta.isexpr(x, :call) + # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. + f = get_function(x.args[1]) + _might_produce = !isa(f, Union{Core.IntrinsicFunction,Core.Builtin,DataType}) + return _might_produce + end + return false +end + +get_function(x) = x +get_function(x::Expr) = eval(x) +get_function(x::GlobalRef) = isconst(x) ? getglobal(x.mod, x.name) : x.binding + +""" + produce_value(x::Expr) + +Returns the value that a `produce` statement returns. For example, for the statment +`produce(%x)`, this function will return `%x`. +""" +function produce_value(x::Expr) + is_produce_stmt(x) || throw(error("Not a produce statement. Please report this error.")) + Meta.isexpr(x, :invoke) && return x.args[3] + return x.args[2] # must be a `:call` Expr. 
+end
+
+struct ProducedValue{T}
+    x::T
+end
+ProducedValue(::Type{T}) where {T} = ProducedValue{Type{T}}(T)
+
+@inline Base.getindex(x::ProducedValue) = x.x
+
+"""
+    inc_args(stmt::T)::T where {T}
+
+Returns a new `T` which is equal to `stmt`, except any `Argument`s present in `stmt` are
+incremented by `1`. For example
+```jldoctest
+julia> Libtask.inc_args(Core.ReturnNode(Core.Argument(1)))
+:(return _2)
+```
+"""
+inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...)
+inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x
+inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest)
+inc_args(x::IDGotoNode) = x
+function inc_args(x::IDPhiNode)
+    new_values = Vector{Any}(undef, length(x.values))
+    for n in eachindex(x.values)
+        if isassigned(x.values, n)
+            new_values[n] = __inc(x.values[n])
+        end
+    end
+    return IDPhiNode(x.edges, new_values)
+end
+inc_args(::Nothing) = nothing
+inc_args(x::GlobalRef) = x
+inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ))
+
+__inc(x::Argument) = Argument(x.n + 1)
+__inc(x) = x
+
+const TypeInfo = Tuple{Vector{Any},Dict{ID,Type}}
+
+"""
+    _typeof(x)
+
+Central definition of typeof, which is specific to the use required in this package.
+Largely the same as `Base._stable_typeof`, differing only in a handful of
+situations, for example:
+```jldoctest
+julia> Base._stable_typeof((Float64,))
+Tuple{DataType}
+
+julia> Libtask._typeof((Float64,))
+Tuple{Type{Float64}}
+```
+"""
+_typeof(x) = Base._stable_typeof(x)
+_typeof(x::Tuple) = Tuple{map(_typeof, x)...}
+_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))}
+
+"""
+    get_type(info::TypeInfo, x)
+
+Returns the static / inferred type associated to `x`.
+""" +get_type(info::TypeInfo, x::Argument) = info[1][x.n - 1] +get_type(info::TypeInfo, x::ID) = CC.widenconst(info[2][x]) +get_type(::TypeInfo, x::QuoteNode) = _typeof(x.value) +get_type(::TypeInfo, x) = _typeof(x) +function get_type(::TypeInfo, x::GlobalRef) + return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty +end +function get_type(::TypeInfo, x::Expr) + x.head === :boundscheck && return Bool + return error("Unrecognised expression $x found in argument slot.") +end + +function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} + # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s + # to implement `TapedTask`s, this appears via the first argument. + refs_id = Argument(1) + + # Increment all arguments by 1. + for bb in ir.blocks, (n, inst) in enumerate(bb.insts) + bb.insts[n] = CC.NewInstruction( + inc_args(inst.stmt), inst.type, inst.info, inst.line, inst.flag + ) + end + + # Construct map between SSA IDs and their index in the state data structure and back. + # Also construct a map from each ref index to its type. We only construct `Ref`s + # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful + # return value, so there's no need to allocate a `Ref` for them. 
+ ssa_id_to_ref_index_map = Dict{ID,Int}() + ref_index_to_ssa_id_map = Dict{Int,ID}() + ref_index_to_type_map = Dict{Int,Type}() + id_to_type_map = Dict{ID,Type}() + is_used_dict = characterise_used_ids(collect_stmts(ir)) + n = 0 + for bb in ir.blocks + for (id, stmt) in zip(bb.inst_ids, bb.insts) + id_to_type_map[id] = CC.widenconst(stmt.type) + stmt.stmt isa IDGotoNode && continue + stmt.stmt isa IDGotoIfNot && continue + stmt.stmt === nothing && continue + stmt.stmt isa ReturnNode && continue + is_used_dict[id] || continue + n += 1 + ssa_id_to_ref_index_map[id] = n + ref_index_to_ssa_id_map[n] = id + ref_index_to_type_map[n] = CC.widenconst(stmt.type) + end + end + + # Specify data structure containing `Ref`s for all of the SSAs. + _refs = Any[Ref{ref_index_to_type_map[p]}() for p in 1:length(ref_index_to_ssa_id_map)] + + # Ensure that each basic block ends with a non-producing statement. This is achieved by + # replacing any fall-through terminators with `IDGotoNode`s. This is not strictly + # necessary, but simplifies later stages of the pipeline, as discussed variously below. + for (n, block) in enumerate(ir.blocks) + if terminator(block) === nothing + # Fall-through terminator, so next block in `ir.blocks` is the unique successor + # block of `block`. Final block cannot have a fall-through terminator, so asking + # for element `n + 1` is always going to be valid. + successor_id = ir.blocks[n + 1].id + push!(block.insts, new_inst(IDGotoNode(successor_id))) + push!(block.inst_ids, ID()) + end + end + + # For each existing basic block, create a sequence of `NamedTuple`s which + # define the manner in which it must be split. + # A block will in general be split as follows: + # 1 - %1 = φ(...) + # 1 - %2 = φ(...) + # 1 - %3 = call_which_must_not_produce(...) + # 1 - %4 = produce(%3) + # 2 - %5 = call_which_must_not_produce(...) + # 2 - %6 = call_which_might_produce(...) + # 3 - %7 = call_which_must_not_produce(...) 
+    #   3 - terminator (GotoIfNot, GotoNode, etc)
+    #
+    # The numbers on the left indicate which split each statement falls into. The first
+    # split comprises all statements up until the first produce / call-which-might-produce.
+    # Consequently, the first split will always contain any `PhiNode`s present in the block.
+    # The next set of statements up until the next produce / call-which-might-produce form
+    # the second split, and so on.
+    # We enforced above the condition that the final statement in a basic block must not
+    # produce. This ensures that the final split does not produce. While not strictly
+    # necessary, this simplifies the implementation (see below).
+    #
+    # As a result of the above, a basic block will be associated to exactly one split if it
+    # does not contain any statements which may produce.
+    #
+    # Each `NamedTuple` contains a `start` index and `last` index, indicating the position
+    # in the block at which the corresponding split starts and finishes.
+    all_splits = map(ir.blocks) do block
+        split_ends = vcat(
+            findall(
+                inst -> stmt_might_produce(inst.stmt, CC.widenconst(inst.type)),
+                block.insts,
+            ),
+            length(block),
+        )
+        return map(enumerate(split_ends)) do (n, split_end)
+            return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end)
+        end
+    end
+
+    # Owing to splitting blocks up, we will need to re-label some `GotoNode`s and
+    # `GotoIfNot`s. To understand this, consider the following block, whose original `ID`
+    # we assume to be `ID(old_id)`.
+    #   ID(new_id) - %1 = φ(ID(3) => ...)
+    #   ID(new_id) - %3 = call_which_must_not_produce(...)
+    #   ID(new_id) - %4 = produce(%3)
+    #   ID(old_id) - GotoNode(ID(5))
+    #
+    # In the above, the entire block was originally associated to a single ID, `ID(old_id)`,
+    # but is now split into two sections. We keep the original ID for the final split, and
+    # assign a new one to the first split. As a result, any `PhiNode`s in other blocks
+    # which have edges incoming from `ID(old_id)` will remain valid.
+ # However, if we adopt this strategy for all blocks, `ID(5)` in the `GotoNode` at the + # end of the block will refer to the wrong block if the block original associated to + # `ID(5)` was itself split, since the "top" of that block will have a new `ID`. + # + # To resolve this, we: + # 1. Associate an ID to each split in each block, ensuring that the ID for the final + # split of each block is the same ID as that of the original block. + all_split_ids = map(zip(ir.blocks, all_splits)) do (block, splits) + return vcat([ID() for _ in splits[1:(end - 1)]], block.id) + end + + # 2. Construct a map between the ID of each block and the ID associated to its split. + top_split_id_map = Dict{ID,ID}(b.id => x[1] for (b, x) in zip(ir.blocks, all_split_ids)) + + # 3. Update all `GotoNode`s and `GotoIfNot`s to refer to these new names. + for block in ir.blocks + t = terminator(block) + if t isa IDGotoNode + block.insts[end] = new_inst(IDGotoNode(top_split_id_map[t.label])) + elseif t isa IDGotoIfNot + block.insts[end] = new_inst(IDGotoIfNot(t.cond, top_split_id_map[t.dest])) + end + end + + # A set of blocks from which we might wish to resume computation. + resume_block_ids = Vector{ID}() + + # A list onto which we'll push the type of any statement which might produce. + possible_produce_types = Any[] + + # This where most of the action happens. + # + # For each split of each block, we must + # 1. translate all statements which accept any SSAs as arguments, or return a value, + # into statements which read in data from the `Ref`s containing the value associated + # to each SSA, and write the result to `Ref`s associated to the SSA of the line in + # question. + # 2. add additional code at the end of the split to handle the possibility that the + # last statement produces (per the definition of the splits above). This applies to + # all splits except the last, which cannot produce by construction. 
Exactly what + # happens here depends on whether the last statement is a `produce` call, or a + # call-which-might-produce -- see below for specifics. + # + # This code transforms each block (and its splits) into a new collection of blocks. + # Note that the total number of new blocks may be greater than the total number of + # splits, because each split ending in a call-which-might-produce requires more than a + # single block to implement the required resumption functionality. + new_bblocks = map(zip(ir.blocks, all_splits, all_split_ids)) do (bb, splits, splits_ids) + new_blocks = map(enumerate(splits)) do (n, split) + # We'll push ID-NewInstruction pairs to this as we proceed through the split. + inst_pairs = IDInstPair[] + + # PhiNodes: + # + # A single `PhiNode` + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%n)) + # + # sets `ID(%1)` to either `1` or whatever value is currently associated to + # `ID(%n)`, depending upon whether the predecessor block was `ID(#1)` or + # `ID(#2)`. Consequently, a single `PhiNode` can be transformed into something + # along the lines of: + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => TupleRef(ref_ind_for_ID(%n))) + # ID(%2) = deref_phi(refs, ID(%1)) + # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%2)) + # + # where `deref_phi` retrieves the value in position `ref_ind_for_ID(%n)` if + # ID(%1) is a `TupleRef`, and `1` otherwise, and `set_ref_at!` sets the `Ref` + # at position `ref_ind_for_ID(%1)` to the value of `ID(%2)`. See the actual + # implementations below. + # + # If we have multiple `PhiNode`s at the start of a block, we must run all of + # them, then dereference all of their values, and finally write all of the + # de-referenced values to the appropriate locations. This is because + # a. we require all `PhiNode`s appear together at the top of a given basic + # block, and + # b. the semantics of `PhiNode`s is that they are all "run" simultaneously. 
This + # only matters if one `PhiNode` in the block can refer to the value stored in + # the SSA associated to another. For example, something along the lines of: + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%2)) + # ID(%2) = φ(ID(#1) => 1, ID(#2) => 2) + # + # (we leave it as an exercise for the reader to figure out why this particular + # semantic feature of `PhiNode`s is relevant in this specific case). + # + # So, in general, the code produced by this block will look roughly like + # + # ID(%1) = φ(...) + # ID(%2) = φ(...) + # ID(%3) = φ(...) + # ID(%4) = deref_phi(refs, ID(%1)) + # ID(%5) = deref_phi(refs, ID(%2)) + # ID(%6) = deref_phi(refs, ID(%3)) + # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%4)) + # set_ref_at!(refs, ref_ind_for_ID(%2), ID(%5)) + # set_ref_at!(refs, ref_ind_for_ID(%3), ID(%6)) + if n == 1 + # Find all PhiNodes in the block -- will definitely be in this split. + phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts) + + # Replace SSA IDs with `TupleRef`s, and record these instructions. + phi_ids = map(phi_inds) do n + phi = bb.insts[n].stmt + for i in eachindex(phi.values) + isassigned(phi.values, i) || continue + v = phi.values[i] + v isa ID || continue + phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v]) + end + phi_id = ID() + push!(inst_pairs, (phi_id, new_inst(phi, Any))) + return phi_id + end + + # De-reference values associated to `IDPhiNode`s. + deref_ids = map(phi_inds) do n + id = bb.inst_ids[n] + phi_id = phi_ids[n] + ref_ind = ssa_id_to_ref_index_map[id] + push!( + inst_pairs, + # The last argument, ref_index_to_type_map[ref_ind], is a + # performance optimisation. The idea is that we know the inferred + # type of the PhiNode from the original IR, and by passing it to + # deref_phi we can type annotate the element type of the Ref + # that it's being dereferenced, resulting in more concrete types + # in the generated IR. 
+ ( + id, + new_inst( + Expr( + :call, + deref_phi, + refs_id, + phi_id, + ref_index_to_type_map[ref_ind], + ), + ), + ), + ) + return id + end + + # Update values stored in `Ref`s associated to `PhiNode`s. + for n in phi_inds + ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) + push!(inst_pairs, (ID(), new_inst(expr))) + end + end + + # Statements which do not produce: + # + # Iterate every statement in the split other than the final one, replacing uses + # of SSAs with de-referenced `Ref`s, and writing the results of statements to + # the corresponding `Ref`s. + _ids = view(bb.inst_ids, (split.start):(split.last - 1)) + _insts = view(bb.insts, (split.start):(split.last - 1)) + for (id, inst) in zip(_ids, _insts) + stmt = inst.stmt + if Meta.isexpr(stmt, :invoke) || + Meta.isexpr(stmt, :call) || + Meta.isexpr(stmt, :new) || + Meta.isexpr(stmt, :foreigncall) || + Meta.isexpr(stmt, :throw_undef_if_not) + + # Find any `ID`s and replace them with calls to read whatever is stored + # in the `Ref`s that they are associated to. + for (n, arg) in enumerate(stmt.args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (new_id, new_inst(expr))) + stmt.args[n] = new_id + end + + # Push the target instruction to the list. + push!(inst_pairs, (id, inst)) + + # If we know it is not possible for this statement to contain any calls + # to produce, then simply write out the result to its `Ref`. If it is + # never used, then there is no need to store it. 
+ if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end + elseif Meta.isexpr(stmt, :boundscheck) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :code_coverage_effect) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_begin) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_end) + push!(inst_pairs, (id, inst)) + elseif stmt isa Nothing + push!(inst_pairs, (id, inst)) + elseif stmt isa GlobalRef + ref_ind = ssa_id_to_ref_index_map[id] + # We can only use `stmt` as an argument to `set_ref_at!` if it is a + # `const` binding. If it's not const, then we need to generate a new SSA + # value for it. + set_ref_at_arg = if isconst(stmt) + stmt + else + new_id = ID() + push!(inst_pairs, (new_id, new_inst(stmt))) + new_id + end + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, set_ref_at_arg) + push!(inst_pairs, (id, new_inst(expr))) + elseif stmt isa Core.PiNode + if stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) + else + push!(inst_pairs, (id, inst)) + end + set_ind = ssa_id_to_ref_index_map[id] + set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) + push!(inst_pairs, (ID(), new_inst(set_expr))) + elseif stmt isa IDPhiNode + # do nothing -- we've already handled any `PhiNode`s. + elseif Meta.isexpr(stmt, :loopinfo) + push!(inst_pairs, (id, inst)) + else + throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) + end + end + + # TODO: explain this better. + new_blocks = BBlock[] + + # Produce and Terminators: + # + # Handle the last statement in the split. 
+ id = bb.inst_ids[split.last] + inst = bb.insts[split.last] + stmt = inst.stmt + if n == length(splits) + # This is the last split in the block, so it must end with a non-producing + # terminator. We handle this in a similar way to the statements above. + + if stmt isa ReturnNode + # Reset the position counter to `-1`, so that if this function gets + # called again, execution starts from the beginning. + expr = Expr(:call, set_resume_block!, refs_id, Int32(-1)) + push!(inst_pairs, (ID(), new_inst(expr))) + # If returning an SSA, it might be one whose value was restored from + # before. Therefore, grab it out of storage, rather than assuming that + # it is def-ed. + if isdefined(stmt, :val) && stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoIfNot + # If the condition is an SSA, it might be one whose value was restored + # from before. Therefore, grab it out of storage, rather than assuming + # that it is defined. 
+ if stmt.cond isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.cond] + cond_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (cond_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoNode + push!(inst_pairs, (id, inst)) + else + error("Unexpected terminator $stmt") + end + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + elseif is_produce_stmt(stmt) + # This is a statement of the form + # %n = produce(arg) + # + # We transform this into + # Libtask.set_resume_block!(refs_id, id_of_next_block) + # return ProducedValue(arg) + # + # The point is to ensure that, next time that this `TapedTask` is called, + # computation is resumed from the statement _after_ this produce statement, + # and to return whatever this produce statement returns. + + # Log the result type of this statement. + arg = stmt.args[Meta.isexpr(stmt, :invoke) ? 3 : 2] + push!(possible_produce_types, get_type((ir.argtypes, id_to_type_map), arg)) + + # When this TapedTask is next called, we should resume from the first + # statement of the next split. + resume_id = splits_ids[n + 1] + push!(resume_block_ids, resume_id) + + # Insert statement to enforce correct resumption behaviour. + resume_stmt = Expr(:call, set_resume_block!, refs_id, resume_id.id) + push!(inst_pairs, (ID(), new_inst(resume_stmt))) + + # Insert statement to construct a `ProducedValue` from the value. + # Could be that the produce references an SSA, in which case we need to + # de-reference, rather than just return the thing. + prod_val = produce_value(stmt) + if prod_val isa ID + deref_id = ID() + ref_ind = ssa_id_to_ref_index_map[prod_val] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (deref_id, new_inst(expr))) + prod_val = deref_id + end + + # Set the ref for this statement, as we would for any other call or invoke. 
+ # The TapedTask may need to read this ref when it resumes, if the return + # value of `produce` is used within the original function. + if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end + + # Construct a `ProducedValue`. + val_id = ID() + push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) + + # Insert statement to return the `ProducedValue`. + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + + # Construct a single new basic block from all of the inst-pairs. + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + else + # The final statement is one which might produce, but is not itself a + # `produce` statement. For example + # y = f(x) + # + # becomes (morally speaking) + # y = f(x) + # if y isa ProducedValue + # set_resume_block!(refs_id, id_of_current_block) + # return y + # end + # + # The point is to ensure that, if `f` "produces" (as indicated by `y` being + # a `ProducedValue`) then the next time that this TapedTask is called, we + # must resume from the call to `f`, as subsequent runs might also produce. + # On the other hand, if anything other than a `ProducedValue` is returned, + # we know that `f` has nothing else to produce, and execution can safely + # continue to the next split. + # In addition to the above, we must do the usual thing and ensure that any + # ssas are read from storage, and write the result of this computation to + # storage before continuing to the next instruction. + # + # You should look at the IR generated by a simple example in the test suite + # which involves calls that might produce, in order to get a sense of what + # the resulting code looks like prior to digging into the code below. + + # At present, we're not able to properly infer the values which might + # potentially be produced by a call-which-might-produce. 
Consequently, we + # have to assume they can produce anything. + # + # This `Any` only affects the return type of the function being derived + # here. Importantly, it does not affect the type stability of subsequent + # statements in this function. As a result, the impact ought to be + # reasonably limited. + push!(possible_produce_types, Any) + + # Create a new basic block from the existing statements, since all new + # statement need to live in their own basic blocks. + callable_block_id = ID() + push!(inst_pairs, (ID(), new_inst(IDGotoNode(callable_block_id)))) + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + + # Derive TapedTask for this statement. + (callable, callable_args) = if Meta.isexpr(stmt, :invoke) + sig = get_mi(stmt.args[1]).specTypes + v = Any[Any] + (LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end]) + elseif Meta.isexpr(stmt, :call) + (DynamicCallable(), stmt.args) + else + display(stmt) + println() + error("unhandled statement which might produce $stmt") + end + + # Find any `ID`s and replace them with calls to read whatever is stored + # in the `Ref`s that they are associated to. + callable_inst_pairs = IDInstPair[] + for (n, arg) in enumerate(callable_args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(callable_inst_pairs, (new_id, new_inst(expr))) + callable_args[n] = new_id + end + + # Allocate a slot in the _refs vector for this callable. + push!(_refs, Ref(callable)) + callable_ind = length(_refs) + + # Retrieve the callable from the refs. + callable_id = ID() + callable_stmt = Expr(:call, get_ref_at, refs_id, callable_ind) + push!(callable_inst_pairs, (callable_id, new_inst(callable_stmt))) + + # Call the callable. + result_id = ID() + result_stmt = Expr(:call, callable_id, callable_args...) 
+ push!(callable_inst_pairs, (result_id, new_inst(result_stmt))) + + # Determine whether this TapedTask has produced a not-a-`ProducedValue`. + not_produced_id = ID() + not_produced_stmt = Expr(:call, not_a_produced, result_id) + push!(callable_inst_pairs, (not_produced_id, new_inst(not_produced_stmt))) + + # Go to a block which just returns the `ProducedValue`, if a + # `ProducedValue` is returned, otherwise continue to the next split. + is_produced_block_id = ID() + is_not_produced_block_id = ID() + switch = Switch( + Any[not_produced_id], + [is_produced_block_id], + is_not_produced_block_id, + ) + push!(callable_inst_pairs, (ID(), new_inst(switch))) + + # Push the above statements onto a new block. + push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs)) + + # Construct block which handles the case that we got a `ProducedValue`. If + # this happens, it means that `callable` has more things to produce still. + # This means that we need to call it again next time we enter this function. + # To achieve this, we set the resume block to the `callable_block_id`, + # and return the `ProducedValue` currently located in `result_id`. + push!(resume_block_ids, callable_block_id) + set_res = Expr(:call, set_resume_block!, refs_id, callable_block_id.id) + return_id = ID() + produced_block_inst_pairs = IDInstPair[ + (ID(), new_inst(set_res)), + (return_id, new_inst(ReturnNode(result_id))), + ] + push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs)) + + # Construct block which handles the case that we did not get a + # `ProducedValue`. In this case, we must first push the result to the `Ref` + # associated to the call, and goto the next split. 
+ next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator + if is_used_dict[id] + result_ref_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id) + else + set_ref = nothing + end + not_produced_block_inst_pairs = IDInstPair[ + (ID(), new_inst(set_ref)) + (ID(), new_inst(IDGotoNode(next_block_id))) + ] + push!( + new_blocks, + BBlock(is_not_produced_block_id, not_produced_block_inst_pairs), + ) + end + return new_blocks + end + return reduce(vcat, new_blocks) + end + new_bblocks = reduce(vcat, new_bblocks) + + # Insert statements at the top. + cases = map(resume_block_ids) do id + return ID(), id, Expr(:call, resume_block_is, refs_id, id.id) + end + cond_ids = ID[x[1] for x in cases] + cond_dests = ID[x[2] for x in cases] + cond_stmts = Any[x[3] for x in cases] + switch_stmt = Switch(Any[x for x in cond_ids], cond_dests, first(new_bblocks).id) + entry_stmts = vcat(cond_stmts, nothing, switch_stmt) + entry_block = BBlock(ID(), vcat(cond_ids, ID(), ID()), map(new_inst, entry_stmts)) + new_bblocks = vcat(entry_block, new_bblocks) + + # New argtypes are the same as the old ones, except we have `Ref`s in the first argument + # rather than nothing at all. + new_argtypes = copy(ir.argtypes) + refs = (_refs..., Ref{Int32}(-1)) + new_argtypes = vcat(typeof(refs), copy(ir.argtypes)) + + # Return BBCode and the `Ref`s. + @static if VERSION >= v"1.12-" + new_ir = BBCode( + new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds + ) + else + new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) + end + return new_ir, refs, possible_produce_types +end + +# Helper used in `derive_copyable_task_ir`. +@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][] + +# Helper used in `derive_copyable_task_ir`. 
+@inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple} + refs[n][] = val + return nothing +end + +# Helper used in `derive_copyable_task_ir`. +@inline function set_resume_block!(refs::R, id::Int32) where {R<:Tuple} + refs[end][] = id + return nothing +end + +# Helper used in `derive_copyable_task_ir`. +@inline resume_block_is(refs::R, id::Int32) where {R<:Tuple} = !(refs[end][] === id) + +# Helper used in `derive_copyable_task_ir`. +@inline function deref_phi(refs::R, n::TupleRef, ::Type{T}) where {R<:Tuple,T} + ref = refs[n.n] + return ref[]::T +end +@inline deref_phi(::R, x, t::Type) where {R<:Tuple} = x + +# Helper used in `derived_copyable_task_ir`. +@inline not_a_produced(x) = !(isa(x, ProducedValue)) + +# Implement iterator interface. +function Base.iterate(t::TapedTask, state::Nothing=nothing) + v = consume(t) + return v === nothing ? nothing : (v, nothing) +end +Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() + +""" + LazyCallable + +Used to implement static dispatch, while avoiding the need to construct the +callable immediately. When constructed, just stores the signature of the +callable and its return type. Constructs the callable when first called. + +All type information is known, so it is possible to make this callable type +stable provided that the return type is concrete. +""" +mutable struct LazyCallable{sig<:Tuple,Tret} + mc::MistyClosure + position::Base.RefValue{Int32} + LazyCallable{sig,Tret}() where {sig,Tret} = new{sig,Tret}() +end + +function (l::LazyCallable)(args::Vararg{Any,N}) where {N} + isdefined(l, :mc) || construct_callable!(l) + return l.mc(args...) +end + +function construct_callable!(l::LazyCallable{sig}) where {sig} + mc, pos = build_callable(sig) + l.mc = mc + l.position = pos + return nothing +end + +""" + DynamicCallable + +Like [`LazyCallable`](@ref), but without any type information. Used to implement +dynamic dispatch. 
+""" +mutable struct DynamicCallable{V} + cache::V +end + +DynamicCallable() = DynamicCallable(Dict{Any,Any}()) + +function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} + sig = _typeof(args) + callable = get(dynamic_callable.cache, sig, nothing) + if callable === nothing + callable = build_callable(sig) + dynamic_callable.cache[sig] = callable + end + return callable[1](args...) +end From 201fa0e5afdb9d5f4f423a8e785c5353c80ed89e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 24 Mar 2026 20:37:21 +0000 Subject: [PATCH 02/10] Add pretty-printing for BBCode --- src/bbcode.jl | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/src/bbcode.jl b/src/bbcode.jl index bdffd55f..1d451e36 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -569,4 +569,107 @@ end _find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing _find_id_uses!(d::Dict{ID,Bool}, x) = nothing +# Pretty-printing + +_id_str(id::ID) = string("%", id.id) +_block_str(id::ID) = string("#", id.id) + +_val_str(x::ID) = _id_str(x) +_val_str(x::Argument) = string("_", x.n) +_val_str(x::QuoteNode) = repr(x) +_val_str(x::GlobalRef) = string(x) +_val_str(x::Nothing) = "nothing" +_val_str(x) = repr(x) + +function Base.show(io::IO, id::ID) + return print(io, _id_str(id)) +end + +function Base.show(io::IO, node::IDPhiNode) + print(io, "φ (") + for (i, edge) in enumerate(node.edges) + print(io, _block_str(edge), " => ") + if isassigned(node.values, i) + print(io, _val_str(node.values[i])) + else + print(io, "#undef") + end + i < length(node.edges) && print(io, ", ") + end + return print(io, ")") +end + +function Base.show(io::IO, node::IDGotoNode) + return print(io, "goto ", _block_str(node.label)) +end + +function Base.show(io::IO, node::IDGotoIfNot) + return print(io, "goto ", _block_str(node.dest), " if not ", _val_str(node.cond)) +end + +function Base.show(io::IO, sw::Switch) + print(io, "switch ") + for (i, (cond, dest)) in 
enumerate(zip(sw.conds, sw.dests)) + print(io, _val_str(cond), " => ", _block_str(dest)) + i < length(sw.conds) && print(io, ", ") + end + return print(io, ", fallthrough ", _block_str(sw.fallthrough_dest)) +end + +function _stmt_str(stmt) + stmt isa Union{IDPhiNode,IDGotoNode,IDGotoIfNot,Switch} && return sprint(show, stmt) + stmt isa ReturnNode && + return isdefined(stmt, :val) ? string("return ", _val_str(stmt.val)) : "unreachable" + stmt isa Expr && return _expr_str(stmt) + stmt isa PiNode && return string("π (", _val_str(stmt.val), ", ", stmt.typ, ")") + return _val_str(stmt) +end + +function _expr_str(x::Expr) + if x.head === :call + f = _val_str(x.args[1]) + args = join((_val_str(a) for a in x.args[2:end]), ", ") + return string(f, "(", args, ")") + end + args = join((_val_str(a) for a in x.args), ", ") + return string("Expr(:", x.head, ", ", args, ")") +end + +function _type_str(@nospecialize(t)) + t === Any && return "" + t === Union{} && return "::Union{}" + return string("::", t) +end + +_is_terminator_stmt(stmt) = stmt isa Terminator || stmt isa ReturnNode + +function Base.show(io::IO, bb::BBlock) + print(io, _block_str(bb.id), " ─") + n = length(bb.insts) + for (i, (id, inst)) in enumerate(zip(bb.inst_ids, bb.insts)) + println(io) + prefix = i < n ? 
"│ " : "└──" + stmt = inst.stmt + if _is_terminator_stmt(stmt) && i == n + print(io, prefix, " ", _stmt_str(stmt)) + else + print( + io, prefix, " ", _id_str(id), " = ", _stmt_str(stmt), _type_str(inst.type) + ) + end + end +end + +function Base.show(io::IO, ir::BBCode) + println(io, "BBCode (", length(ir.argtypes), " args, ", length(ir.blocks), " blocks)") + for (i, block) in enumerate(ir.blocks) + show(io, block) + i < length(ir.blocks) && println(io) + end +end + +function Base.show(io::IO, ::MIME"text/plain", ir::BBCode) + return show(io, ir) +end + end From 87939f4211c19af49e9a3c9cf4beb22a3752d9d2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 24 Mar 2026 20:37:27 +0000 Subject: [PATCH 03/10] Improve granularity of `generate_ir` debugging tool --- src/Libtask.jl | 1 + src/copyable_task.jl | 53 ++++++++++++++++++++++++++++---------------- src/refelim.jl | 3 +++ 3 files changed, 38 insertions(+), 19 deletions(-) create mode 100644 src/refelim.jl diff --git a/src/Libtask.jl b/src/Libtask.jl index 2d6475bd..415101c2 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -15,6 +15,7 @@ using .BasicBlockCode include("copyable_task.jl") include("transformation.jl") +include("refelim.jl") include("test_utils.jl") export TapedTask, diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 569bda08..17db7013 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -100,23 +100,33 @@ function _throw_ir_error(@nospecialize(sig::Type{<:Tuple})) end """ - generate_ir(optimise::Bool, f, args...; kwargs...) + generate_ir(stage::Symbol, f, args...; kwargs...) -Returns `(original_ir, transformed_ir)` for the call `f(args...; kwargs...)`. +Returns an IR output for the call `f(args...; kwargs...)`. -The first element is the original `IRCode` that Julia generates for the call. The second is -the transformed `IRCode` that Libtask would use to implement the `produce`/`consume` -interface in a `TapedTask`. 
+`stage` controls the form of the IR that is returned, and corresponds to each stage in the +generation of a `TapedTask`. The options are: -`optimise` controls whether the transformed IR (i.e. for the `TapedTask`) is optimised or -not. Apart from inspecting the effects of optimisation, setting `optimise` to `false` can -also be useful for debugging, because the optimisation pass will also perform verification, -and will error if the IR is malformed (which can happen if Libtask's transformation pass has -bugs). +- `:original` - no transformations applied, this generates a `Core.Compiler.IRCode` + corresponding to the original function call. -This is intended purely as a debugging tool, and is not exported. +- `:originalbb`: same as `:original` but converted to `Libtask.BasicBlock.BBCode` + +- `:transformedbb`: same as `originalbb` but with transformations applied to convert it + into a form that supports the `produce`-`consume` interface. + +- `:toptbb`: same as `transformedbb` but with Libtask optimisations to reduce the number of + stored references + +- `:transformed`: same as `toptbb` but converted back to `Core.Compiler.IRCode` + +- `:final`: same as `transformed` but with Julia's builtin SSA IR optimisations applied + to it. This is the IRCode that is eventually wrapped in the `MistyClosure`. + +This is intended purely as a debugging tool, and is not exported. Breaking changes to the +interface may occur at any time. """ -function generate_ir(optimise::Bool, fargs...; kwargs...) +function generate_ir(stage::Symbol, fargs...; kwargs...) all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...) sig = typeof(all_args) ir_results = Base.code_ircode_by_type(sig) @@ -124,14 +134,19 @@ function generate_ir(optimise::Bool, fargs...; kwargs...) 
_throw_ir_error(sig) end original_ir = ir_results[1][1] + stage == :original && return original_ir seed_id!() - bb, _, _ = derive_copyable_task_ir(BBCode(original_ir)) - transformed_ir = if optimise - optimise_ir!(IRCode(bb)) - else - IRCode(bb) - end - return original_ir, transformed_ir + original_bb = BBCode(original_ir) + stage == :originalbb && return original_bb + transformed_bb, refs, _ = derive_copyable_task_ir(BBCode(original_ir)) + stage == :transformedbb && return transformed_bb + topt_bb, refs = eliminate_refs(transformed_bb, refs) + stage == :toptbb && return topt_bb + transformed_ir = IRCode(transformed_bb) + stage == :transformed && return transformed_ir + optimise_ir!(transformed_ir) + stage == :final && return transformed_ir + throw(ArgumentError("unknown stage $stage")) end """ diff --git a/src/refelim.jl b/src/refelim.jl new file mode 100644 index 00000000..81a0d029 --- /dev/null +++ b/src/refelim.jl @@ -0,0 +1,3 @@ +function eliminate_refs(ir::BBCode, refs::Vector) + return ir, refs +end From de296669d16d8f88e6a3172d2648a8f30a94af79 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 11:42:02 +0000 Subject: [PATCH 04/10] Implement ref elimination pass --- src/bbcode.jl | 42 ++++++++- src/copyable_task.jl | 3 +- src/refelim.jl | 216 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 4 deletions(-) diff --git a/src/bbcode.jl b/src/bbcode.jl index 1d451e36..1c9b5ae2 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -30,6 +30,7 @@ export ID, insert_before_terminator!, collect_stmts, compute_all_predecessors, + compute_all_successors, BBCode, characterise_used_ids, characterise_unique_predecessor_blocks, @@ -37,7 +38,8 @@ export ID, IDInstPair, __line_numbers_to_block_numbers!, is_reachable_return_node, - new_inst + new_inst, + replace_ids const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() @@ -364,6 +366,44 @@ _block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) 
_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) _block_num_to_ids(d::BlockNumToIdDict, x) = x +# A map from IDs to IDs; useful when generically replacing IDs in BBCode statements +const IdToIdDict = Dict{ID,ID} +function replace_ids(d::IdToIdDict, inst::NewInstruction) + return NewInstruction(inst; stmt=replace_ids(d, inst.stmt)) +end +function replace_ids(d::IdToIdDict, x::ReturnNode) + return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +end +replace_ids(d::IdToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +replace_ids(d::IdToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +replace_ids(d::IdToIdDict, x::QuoteNode) = x +replace_ids(d::IdToIdDict, x) = x +function replace_ids(d::IdToIdDict, x::IDPhiNode) + new_ids = [get(d, e, e) for e in x.edges] + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return IDPhiNode(new_ids, new_values) +end +replace_ids(d::IdToIdDict, x::IDGotoNode) = x +function replace_ids(d::IdToIdDict, x::IDGotoIfNot) + return IDGotoIfNot(get(d, x.cond, x.cond), get(d, x.dest, x.dest)) +end +function replace_ids(d::IdToIdDict, x::Switch) + new_conds = Vector{Any}(undef, length(x.conds)) + for n in eachindex(x.conds) + if isassigned(x.conds, n) + new_conds[n] = get(d, x.conds[n], x.conds[n]) + end + end + new_dests = [get(d, dest, dest) for dest in x.dests] + new_fallthrough_dest = get(d, x.fallthrough_dest, x.fallthrough_dest) + return Switch(new_conds, new_dests, new_fallthrough_dest) +end + # # Converting from BBCode to IRCode # diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 17db7013..0ff1ca9b 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -142,7 +142,7 @@ function generate_ir(stage::Symbol, fargs...; kwargs...) 
stage == :transformedbb && return transformed_bb topt_bb, refs = eliminate_refs(transformed_bb, refs) stage == :toptbb && return topt_bb - transformed_ir = IRCode(transformed_bb) + transformed_ir = IRCode(topt_bb) stage == :transformed && return transformed_ir optimise_ir!(transformed_ir) stage == :final && return transformed_ir @@ -178,6 +178,7 @@ function build_callable(sig::Type{<:Tuple}) # Check whether this is a varargs call. isva = which(sig).isva bb, refs, types = derive_copyable_task_ir(BBCode(ir)) + bb, refs = eliminate_refs(bb, refs) unoptimised_ir = IRCode(bb) @static if VERSION > v"1.12-" # This is a performance optimisation, copied over from Mooncake, where setting diff --git a/src/refelim.jl b/src/refelim.jl index 81a0d029..8d23db09 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -1,3 +1,215 @@ -function eliminate_refs(ir::BBCode, refs::Vector) - return ir, refs +""" + eliminate_refs(ir::BBCode, refs::Vector) + +Transform the `BBCode` to remove redundant `get_ref_at` / `set_ref_at!` calls. + +TODO: Explain more about how this happens + +Returns a tuple of the modified `BBCode` and the modified `refs` vector. +""" +function eliminate_refs(ir::BBCode, refs::Tuple) + # The `refs` tuple contains a series of `Ref`s which are used to maintain function state + # across produce boundaries. This mechanism is, in general, necessary as it allows + # resumption of a function from a produce point with the correct state. However, + # sometimes not all previously defined variables are used in the rest of the function. + # This implies that it is not necessary to maintain the state of those variables across + # produce points. + # + # This function identifies which `Ref`s really do not need to be kept at all, and + # removes them from the IR. The algorithm we use here is a classic 'live variable + # analysis', see e.g., Chapter 8 of Cooper & Torczon 'Engineering a Compiler', 3rd ed. 
+ + # Slightly faffy setup needed here, as we need a function that returns IDs given an + # integer, but we can't call `ID(int)` to get the block ID from the integer; we need to + # locate it from the set of existing IDs. + all_block_ids = Set(block.id for block in ir.blocks) + get_block_id_from_int(i::Integer)::ID = + let + matching_ids = filter(id -> id.id == i, all_block_ids) + isempty(matching_ids) && error("No block ID found for integer $i") + return only(matching_ids) + end + + # Start by constructing, for each basic block `i`, the set of `Ref`s for which the value + # is read in that basic block before being written to. Annoyingly, the literature is + # quite inconsistent in its notation. Cooper & Torczon call this UEVar[i] (i.e., + # 'upward exposed variables'). Sometimes it's called USE[i], and on Wikipedia it's called + # GEN[i]. + # + # Furthermore, we also construct the set of `Ref`s for which the value is written to in + # that basic block. This is `VarKill[i]` in C&T, or `DEF[i]`, or KILL[i] on Wikipedia. + # + # TODO: Handle phi nodes. + use = Dict{ID,Set{Int}}() + def = Dict{ID,Set{Int}}() + # In this pass through the IR, we also capture the resume blocks that Libtask sets in + # each block. This is explained later on. + set_resume_blocks = Dict{ID,Set{ID}}() + for block in ir.blocks + def_i = Set{Int}() # All variables defined in this block + use_i = Set{Int}() # All variables used in this block before being defined in this block + resume_block_i = Set{ID}() + for inst in block.insts + if Meta.isexpr(inst.stmt, :call) + call_func = inst.stmt.args[1] + if call_func == Libtask.set_ref_at! + push!(def_i, inst.stmt.args[3]) + elseif call_func == Libtask.get_ref_at + if !(inst.stmt.args[3] in def_i) + push!(use_i, inst.stmt.args[3]) + end + elseif call_func == Libtask.set_resume_block! + return_block = inst.stmt.args[3] + # A return block of `-1` means the function is ending naturally, and not + # resuming again. 
+ if return_block != -1 + push!(resume_block_i, get_block_id_from_int(return_block)) + end + end + elseif inst.stmt isa IDPhiNode + # For a phi node, we might have wrapped a ref value inside a TupleRef. These + # also count as uses of the ref, so we need to add those to the use set. + for val in inst.stmt.values + if val isa Libtask.TupleRef + push!(use_i, val.n) + end + end + end + end + use[block.id] = use_i + def[block.id] = def_i + set_resume_blocks[block.id] = resume_block_i + end + + # Get a map of successors. + successor_map = compute_all_successors(ir) + # The tricky thing here is that although successor_map is *technically* correct from the + # perspective of the IR, it doesn't capture the fact that we have synthetic edges from + # one block to the other, mediated by calls to `set_resume_block!`. For example, if we + # have something like + # + # #11 ─ + # │ %1 = ... + # │ %2 = Libtask.set_ref_at!(_1, n, %1) + # │ %3 = Libtask.set_resume_block!(_1, 12) + # │ %4 = ... + # └── return %4 + # #12 ─ + # │ %5 = Libtask.get_ref_at(_1, n) + # │ %6 = ... + # └── return %6 + # + # then compute_all_successors will happily say that 12 is *not* a successor of 11, + # because the function returns. However, it's obviously important here that the + # set value of ref `n` is live in block 12, because when the function resumes it + # can jump straight to 12. + # To handle this we need to add synthetic edges from 11 to 12. We do this by parsing + # the IR to find all calls to `set_resume_block!` (which explains the + # `set_resume_blocks` dictionary we constructed above), and then combining that with + # the successor map. + successors = Dict( + id => union(set_resume_blocks[id], Set(natural_successors)) for + (id, natural_successors) in pairs(successor_map) + ) + + # Now we have all the information needed to run the live variable analysis, which is a + # fixed-point iteration. This algorithm is lifted straight from Cooper & Torczon. 
+ changed = true + live_out = Dict{ID,Set{Int}}(id => Set{Int}() for id in all_block_ids) + while (changed) + changed = false + for i in all_block_ids + # Recompute live_out + live_out_i_new = Set{Int}() + for succ_id in successors[i] + live_in_succ = union(use[succ_id], setdiff(live_out[succ_id], def[succ_id])) + live_out_i_new = union(live_out_i_new, live_in_succ) + end + # If it's changed from the previous one, then we need to run another iteration + if live_out_i_new != live_out[i] + changed = true + end + live_out[i] = live_out_i_new + end + end + + # Only the refs that are live at the end of some basic block anywhere in the function + # need to be kept. Note that the last ref in `refs` is always mandatory: it's the one + # that stores the return block (i.e., how far through the function it's progressed). + necessary_ref_ids = sort!(collect(union(values(live_out)...))) + unnecessary_ref_ids = setdiff(1:(length(refs) - 1), necessary_ref_ids) + + # TODO(penelopeysm): We could reduce the size of the ref tuple itself, by dropping refs + # that are never used. I think this is not super important right now: it doesn't really + # hurt to have extra refs lying around in the tuple, because they're just initialised to + # essentially null pointers and never read/written to. But in principle we could get rid + # of them too. + # + # new_refs = tuple( + # [ref for (i, ref) in enumerate(refs) if !(i in unnecessary_ref_ids)]... + # ) + # old_refid_to_new_refid_map = Dict{Int,Int}( + # necessary_ref_ids[i] => i for i in eachindex(necessary_ref_ids) + # ) + + # We now need to go through the IR and remove calls that get/set the unnecessary refs. + new_bblocks = map(ir.blocks) do block + new_insts = IDInstPair[] + # Map, from ref numbers, to the SSA ID that contains the definition of the value + # that would have been stored in that ref. 
+ old_refid_to_ssaid_map = Dict{Int,ID}() + # Map, from SSA IDs that used to contain get_ref_at(refid) values, to the new SSA + # IDs that contain the value itself. + old_ssaid_to_new_ssaid_map = Dict{ID,ID}() + + for (id, inst) in zip(block.inst_ids, block.insts) + if Meta.isexpr(inst.stmt, :call) + call_func = inst.stmt.args[1] + + if call_func == Libtask.set_ref_at! + old_refid = inst.stmt.args[3] + if old_refid in unnecessary_ref_ids + # We can skip this instruction, but first we need to record which + # SSA ID contains the value that we would have set in this ref, so + # that if we encounter a get_ref_at, we can replace it with this + # value. + ssaid = inst.stmt.args[4] + # That value might itself be something that needs to be replaced. + ssaid = get(old_ssaid_to_new_ssaid_map, ssaid, ssaid) + old_refid_to_ssaid_map[old_refid] = inst.stmt.args[4] + else + # It's a set that we still need. + push!(new_insts, (id, inst)) + end + elseif call_func == Libtask.get_ref_at + old_refid = inst.stmt.args[3] + if old_refid in unnecessary_ref_ids + # Eliminate it entirely. + old_ssaid_to_new_ssaid_map[id] = old_refid_to_ssaid_map[old_refid] + else + # It's a get that we still need. + inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, inst)) + end + else + # Some other call instruction. + inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, inst)) + end + else + # Some other (non-call) instruction. 
+ inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, inst)) + end + end + return BBlock(block.id, new_insts) + end + + new_ir = @static if VERSION >= v"1.12-" + BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds) + else + BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.linetable, ir.meta) + end + # return ir, refs + return new_ir, refs end From dbbc852c42812ec33efaf637ebd0c65544d9e1a1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 11:57:06 +0000 Subject: [PATCH 05/10] Handle GlobalRefs --- src/refelim.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/refelim.jl b/src/refelim.jl index 8d23db09..5a64bf72 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -173,10 +173,17 @@ function eliminate_refs(ir::BBCode, refs::Tuple) # SSA ID contains the value that we would have set in this ref, so # that if we encounter a get_ref_at, we can replace it with this # value. - ssaid = inst.stmt.args[4] - # That value might itself be something that needs to be replaced. - ssaid = get(old_ssaid_to_new_ssaid_map, ssaid, ssaid) - old_refid_to_ssaid_map[old_refid] = inst.stmt.args[4] + value_arg = inst.stmt.args[4] + if value_arg isa ID + # That value might itself be something that needs to be replaced. + ssaid = get(old_ssaid_to_new_ssaid_map, value_arg, value_arg) + old_refid_to_ssaid_map[old_refid] = inst.stmt.args[4] + elseif value_arg isa GlobalRef + # If it's a GlobalRef that's being stored in the ref, we just + # need to store the GlobalRef itself inside the SSA ID. + old_refid_to_ssaid_map[old_refid] = id + push!(new_insts, (id, new_inst(value_arg))) + end else # It's a set that we still need. 
push!(new_insts, (id, inst)) From 83a276287510df6efb24f64d4848a7358050e7dd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 12:24:54 +0000 Subject: [PATCH 06/10] Add note --- src/refelim.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/refelim.jl b/src/refelim.jl index 5a64bf72..b763e24c 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -113,7 +113,8 @@ function eliminate_refs(ir::BBCode, refs::Tuple) ) # Now we have all the information needed to run the live variable analysis, which is a - # fixed-point iteration. This algorithm is lifted straight from Cooper & Torczon. + # fixed-point iteration. This algorithm is lifted straight from Cooper & Torczon (figure + # 8.15). changed = true live_out = Dict{ID,Set{Int}}(id => Set{Int}() for id in all_block_ids) while (changed) From 76058108b9ae40327b6837aa1f4da66325dd95e9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 14:47:44 +0000 Subject: [PATCH 07/10] fix a bug where SSA IDs in set_ref_at! weren't being replaced --- src/refelim.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/refelim.jl b/src/refelim.jl index b763e24c..e40f2529 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -178,7 +178,7 @@ function eliminate_refs(ir::BBCode, refs::Tuple) if value_arg isa ID # That value might itself be something that needs to be replaced. ssaid = get(old_ssaid_to_new_ssaid_map, value_arg, value_arg) - old_refid_to_ssaid_map[old_refid] = inst.stmt.args[4] + old_refid_to_ssaid_map[old_refid] = ssaid elseif value_arg isa GlobalRef # If it's a GlobalRef that's being stored in the ref, we just # need to store the GlobalRef itself inside the SSA ID. @@ -187,7 +187,8 @@ function eliminate_refs(ir::BBCode, refs::Tuple) end else # It's a set that we still need. 
- push!(new_insts, (id, inst)) + ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, ninst)) end elseif call_func == Libtask.get_ref_at old_refid = inst.stmt.args[3] @@ -196,18 +197,18 @@ function eliminate_refs(ir::BBCode, refs::Tuple) old_ssaid_to_new_ssaid_map[id] = old_refid_to_ssaid_map[old_refid] else # It's a get that we still need. - inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) - push!(new_insts, (id, inst)) + ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, ninst)) end else # Some other call instruction. - inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) - push!(new_insts, (id, inst)) + ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, ninst)) end else # Some other (non-call) instruction. - inst = replace_ids(old_ssaid_to_new_ssaid_map, inst) - push!(new_insts, (id, inst)) + ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + push!(new_insts, (id, ninst)) end end return BBlock(block.id, new_insts) From bc89258061428d31c099c827505d174efc4ca5a2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 14:50:18 +0000 Subject: [PATCH 08/10] improve generate_ir symbol arguments --- src/copyable_task.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 0ff1ca9b..febf294e 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -107,20 +107,20 @@ Returns an IR output for the call `f(args...; kwargs...)`. `stage` controls the form of the IR that is returned, and corresponds to each stage in the generation of a `TapedTask`. The options are: -- `:original` - no transformations applied, this generates a `Core.Compiler.IRCode` +- `:input_ir` - no transformations applied, this generates a `Core.Compiler.IRCode` corresponding to the original function call. 
 
-- `:originalbb`: same as `:original` but converted to `Libtask.BasicBlock.BBCode`
+- `:input_bb`: same as `:input_ir` but converted to `Libtask.BasicBlockCode.BBCode`
 
-- `:transformedbb`: same as `originalbb` but with transformations applied to convert it
+- `:transformed_bb`: same as `:input_bb` but with transformations applied to convert it
   into a form that supports the `produce`-`consume` interface.
 
-- `:toptbb`: same as `transformedbb` but with Libtask optimisations to reduce the number of
-  stored references
+- `:optimised_bb`: same as `:transformed_bb` but with Libtask optimisations to reduce the
+  number of stored references
 
-- `:transformed`: same as `toptbb` but converted back to `Core.Compiler.IRCode`
+- `:optimised_ir`: same as `:optimised_bb` but converted back to `Core.Compiler.IRCode`
 
-- `:final`: same as `transformed` but with Julia's builtin SSA IR optimisations applied
+- `:final`: same as `:optimised_ir` but with Julia's builtin SSA IR optimisations applied
   to it. This is the IRCode that is eventually wrapped in the `MistyClosure`.
 
 This is intended purely as a debugging tool, and is not exported. Breaking changes to the
 interface may occur at any time.
 """
@@ -134,16 +134,16 @@ function generate_ir(stage::Symbol, fargs...; kwargs...)
_throw_ir_error(sig) end original_ir = ir_results[1][1] - stage == :original && return original_ir + stage == :input_ir && return original_ir seed_id!() original_bb = BBCode(original_ir) - stage == :originalbb && return original_bb + stage == :input_bb && return original_bb transformed_bb, refs, _ = derive_copyable_task_ir(BBCode(original_ir)) - stage == :transformedbb && return transformed_bb + stage == :transformed_bb && return transformed_bb topt_bb, refs = eliminate_refs(transformed_bb, refs) - stage == :toptbb && return topt_bb + stage == :optimised_bb && return topt_bb transformed_ir = IRCode(topt_bb) - stage == :transformed && return transformed_ir + stage == :optimised_ir && return transformed_ir optimise_ir!(transformed_ir) stage == :final && return transformed_ir throw(ArgumentError("unknown stage $stage")) From f34a6cd9ac80f428f9ae60dae92a3c44e334739f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 15:08:43 +0000 Subject: [PATCH 09/10] fix undef phi node values on 1.10 --- src/refelim.jl | 78 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/src/refelim.jl b/src/refelim.jl index e40f2529..683f3a27 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -38,8 +38,6 @@ function eliminate_refs(ir::BBCode, refs::Tuple) # # Furthermore, we also construct the set of `Ref`s for which the value is written to in # that basic block. This is `VarKill[i]` in C&T, or `DEF[i]`, or KILL[i] on Wikipedia. - # - # TODO: Handle phi nodes. use = Dict{ID,Set{Int}}() def = Dict{ID,Set{Int}}() # In this pass through the IR, we also capture the resume blocks that Libtask sets in @@ -53,25 +51,31 @@ function eliminate_refs(ir::BBCode, refs::Tuple) if Meta.isexpr(inst.stmt, :call) call_func = inst.stmt.args[1] if call_func == Libtask.set_ref_at! 
- push!(def_i, inst.stmt.args[3]) + ref_n = inst.stmt.args[3] + push!(def_i, ref_n) elseif call_func == Libtask.get_ref_at - if !(inst.stmt.args[3] in def_i) - push!(use_i, inst.stmt.args[3]) + ref_n = inst.stmt.args[3] + if !(ref_n in def_i) + push!(use_i, ref_n) end elseif call_func == Libtask.set_resume_block! return_block = inst.stmt.args[3] # A return block of `-1` means the function is ending naturally, and not - # resuming again. + # resuming again, so we don't need to add a synthetic edge. if return_block != -1 push!(resume_block_i, get_block_id_from_int(return_block)) end end elseif inst.stmt isa IDPhiNode - # For a phi node, we might have wrapped a ref value inside a TupleRef. These - # also count as uses of the ref, so we need to add those to the use set. - for val in inst.stmt.values - if val isa Libtask.TupleRef - push!(use_i, val.n) + # We might have wrapped a ref value inside a TupleRef as one of the phi + # node's values. These also count as uses of the ref, so we need to add + # those to the use set. + for i in eachindex(inst.stmt.values) + if isassigned(inst.stmt.values, i) + val = inst.stmt.values[i] + if val isa Libtask.TupleRef + push!(use_i, val.n) + end end end end @@ -149,7 +153,7 @@ function eliminate_refs(ir::BBCode, refs::Tuple) # new_refs = tuple( # [ref for (i, ref) in enumerate(refs) if !(i in unnecessary_ref_ids)]... # ) - # old_refid_to_new_refid_map = Dict{Int,Int}( + # refid_to_new_refid_map = Dict{Int,Int}( # necessary_ref_ids[i] => i for i in eachindex(necessary_ref_ids) # ) @@ -158,7 +162,7 @@ function eliminate_refs(ir::BBCode, refs::Tuple) new_insts = IDInstPair[] # Map, from ref numbers, to the SSA ID that contains the definition of the value # that would have been stored in that ref. - old_refid_to_ssaid_map = Dict{Int,ID}() + refid_to_ssaid_map = Dict{Int,ID}() # Map, from SSA IDs that used to contain get_ref_at(refid) values, to the new SSA # IDs that contain the value itself. 
old_ssaid_to_new_ssaid_map = Dict{ID,ID}() @@ -168,35 +172,57 @@ function eliminate_refs(ir::BBCode, refs::Tuple) call_func = inst.stmt.args[1] if call_func == Libtask.set_ref_at! - old_refid = inst.stmt.args[3] - if old_refid in unnecessary_ref_ids + refid = inst.stmt.args[3] + value_arg = inst.stmt.args[4] + if refid in unnecessary_ref_ids # We can skip this instruction, but first we need to record which # SSA ID contains the value that we would have set in this ref, so # that if we encounter a get_ref_at, we can replace it with this # value. - value_arg = inst.stmt.args[4] if value_arg isa ID - # That value might itself be something that needs to be replaced. + # That value might itself be something that needs to be + # replaced. ssaid = get(old_ssaid_to_new_ssaid_map, value_arg, value_arg) - old_refid_to_ssaid_map[old_refid] = ssaid + refid_to_ssaid_map[refid] = ssaid elseif value_arg isa GlobalRef # If it's a GlobalRef that's being stored in the ref, we just - # need to store the GlobalRef itself inside the SSA ID. - old_refid_to_ssaid_map[old_refid] = id + # need to store the GlobalRef itself inside the SSA ID. In other + # words: + # %1 = set_ref_at!(_1, refid, Main.a) + # can be replaced with + # %1 = Main.a + refid_to_ssaid_map[refid] = id push!(new_insts, (id, new_inst(value_arg))) end else - # It's a set that we still need. + # It's a set that we still need. However, we additionally want to + # track the SSA ID that contains the value being set, so that if we + # encounter a get_ref_at in the same block, we can replace the + # get_ref_at with that value directly. + if value_arg isa ID + refid_to_ssaid_map[refid] = value_arg + elseif value_arg isa GlobalRef + # Create a new SSA ID that points to the GlobalRef. 
+ new_id = ID() + push!(new_insts, (new_id, new_inst(value_arg))) + refid_to_ssaid_map[refid] = new_id + end ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) push!(new_insts, (id, ninst)) end elseif call_func == Libtask.get_ref_at - old_refid = inst.stmt.args[3] - if old_refid in unnecessary_ref_ids - # Eliminate it entirely. - old_ssaid_to_new_ssaid_map[id] = old_refid_to_ssaid_map[old_refid] + refid = inst.stmt.args[3] + if haskey(refid_to_ssaid_map, refid) + # If `refid` was found in the `refid_to_ssaid_map`, that means we + # have an SSA ID that contains the value that `get_ref_at` would + # have returned anyway. So, we can skip this instruction entirely. + # + # However, we need to record that the SSA ID of the current + # `get_ref_at` instruction should be replaced with that SSA ID, so + # that future instructions don't reference the current ID. + old_ssaid_to_new_ssaid_map[id] = refid_to_ssaid_map[refid] else - # It's a get that we still need. + # It's a get that we legitimately still need. 
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) push!(new_insts, (id, ninst)) end From 5ea45dd78d07d6ac64f0ee79ef2c2a553161fdfa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 25 Mar 2026 15:17:55 +0000 Subject: [PATCH 10/10] add docstring to docs --- docs/src/internals.md | 1 + src/refelim.jl | 45 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/docs/src/internals.md b/docs/src/internals.md index 1311d099..c4de309a 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -18,4 +18,5 @@ Libtask.LazyCallable Libtask.DynamicCallable Libtask.callable_ret_type Libtask.fresh_copy +Libtask.eliminate_refs ``` diff --git a/src/refelim.jl b/src/refelim.jl index 683f3a27..0c67013c 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -1,11 +1,50 @@ """ - eliminate_refs(ir::BBCode, refs::Vector) + eliminate_refs(ir::BBCode, refs::Tuple) Transform the `BBCode` to remove redundant `get_ref_at` / `set_ref_at!` calls. -TODO: Explain more about how this happens +This optimises both within a single basic block, and across basic blocks. As an example of +the former, if we have something like -Returns a tuple of the modified `BBCode` and the modified `refs` vector. + %11 = set_ref_at!(_1, 1, %1) + %12 = get_ref_at(_1, 1) + %13 = f(%12) + +then we can replace `%12` with `%1`, and eliminate the load entirely: + + %11 = set_ref_at!(_1, 1, %1) + %13 = f(%1) + +As for an example of optimising across basic blocks, consider a function like + + function f(x) + a = x + 1 + b = a + 2 + produce(b) + c = b * 3 + produce(c) + return nothing + end + +Libtask's transformation pass recognises that `a`, `b`, and `c` all constitute the state of +the function, and creates `Ref`s to store all of those values, and to read/write from them +across produce boundaries. However, in this example, `a` is actually not needed after the +first produce, because it is only used to compute `b`. 
Thus, it is not necessary to retain +its state. + +This function performs a classic live-variable analysis to identify which `Ref`s are +actually needed across boundaries, and eliminates all calls to `set_ref_at!` that don't need +to be retained. + +Returns a tuple of the modified `BBCode` and the modified `refs` tuple. + +!!! note + Right now, `eliminate_refs` does not remove dead refs from the `refs` tuple itself (so the + TapedTask will be constructed with the same `refs` tuple as before). We simply leave those + refs as unused (i.e., they will be initialised with nothing, and never read from or + written to.) In principle, we could also slim down the `refs` tuple itself by removing + the dead refs from it. This is left as a future optimisation (and the signature of this + function is designed to allow for this in the future). """ function eliminate_refs(ir::BBCode, refs::Tuple) # The `refs` tuple contains a series of `Ref`s which are used to maintain function state