diff --git a/src/bbcode.jl b/src/bbcode.jl index 1c9b5ae2..aaecae20 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -371,6 +371,7 @@ const IdToIdDict = Dict{ID,ID} function replace_ids(d::IdToIdDict, inst::NewInstruction) return NewInstruction(inst; stmt=replace_ids(d, inst.stmt)) end +replace_ids(d::IdToIdDict, x::ID) = get(d, x, x) function replace_ids(d::IdToIdDict, x::ReturnNode) return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x end diff --git a/src/refelim.jl b/src/refelim.jl index 2278afbf..f0eb05fd 100644 --- a/src/refelim.jl +++ b/src/refelim.jl @@ -37,14 +37,6 @@ actually needed across boundaries, and eliminates all calls to `set_ref_at!` tha to be retained. Returns a tuple of the modified `BBCode` and the modified `refs` tuple. - -!!! note - Right now, `eliminate_refs` does not remove dead refs from the `refs` tuple itself (so the - TapedTask will be constructed with the same `refs` tuple as before). We simply leave those - refs as unused (i.e., they will be initialised with nothing, and never read from or - written to.) In principle, we could also slim down the `refs` tuple itself by removing - the dead refs from it. This is left as a future optimisation (and the signature of this - function is designed to allow for this in the future). """ function eliminate_refs(ir::BBCode, refs::Tuple) # The `refs` tuple contains a series of `Ref`s which are used to maintain function state @@ -189,21 +181,15 @@ function eliminate_refs(ir::BBCode, refs::Tuple) # Only the refs that are live at the end of some basic block anywhere in the function # need to be kept. Note that the last ref in `refs` is always mandatory: it's the one # that stores the return block (i.e., how far through the function it's progressed). - necessary_ref_ids = sort!(collect(union(values(live_out)...))) - unnecessary_ref_ids = setdiff(1:(length(refs) - 1), necessary_ref_ids) + necessary_ref_ids = sort!(vcat(length(refs), collect(union(values(live_out)...)))) + unnecessary_ref_ids = setdiff(1:length(refs), necessary_ref_ids) - # TODO(penelopeysm): We could reduce the size of the ref tuple itself, by dropping refs - # that are never used. I think this is not super important right now: it doesn't really - # hurt to have extra refs lying around in the tuple, because they're just initialised to - # essentially null pointers and never read/written to. But in principle we could get rid - # of them too. - # - # new_refs = tuple( - # [ref for (i, ref) in enumerate(refs) if !(i in unnecessary_ref_ids)]... - # ) - # refid_to_new_refid_map = Dict{Int,Int}( - # necessary_ref_ids[i] => i for i in eachindex(necessary_ref_ids) - # ) + new_refs = map(i -> refs[i], tuple(necessary_ref_ids...)) + # Suppose that we want to keep refs 1, 4, and 5. Then this map would be Dict(1 => 1, 4 + # => 2, 5 => 3). + refid_to_new_refid_map = Dict{Int,Int}( + refid => i for (i, refid) in enumerate(necessary_ref_ids) + ) # We now need to go through the IR and remove calls that get/set the unnecessary refs. new_bblocks = map(ir.blocks) do block @@ -263,8 +249,14 @@ function eliminate_refs(ir::BBCode, refs::Tuple) else error("Unexpected value argument to set_ref_at!: $value_arg") end - ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) - push!(new_insts, (id, ninst)) + ninst = Expr( + :call, + Libtask.set_ref_at!, + replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[2]), + refid_to_new_refid_map[refid], + replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[4]), + ) + push!(new_insts, (id, new_inst(ninst))) end elseif call_func == Libtask.get_ref_at refid = inst.stmt.args[3] @@ -279,16 +271,44 @@ function eliminate_refs(ir::BBCode, refs::Tuple) old_ssaid_to_new_ssaid_map[id] = refid_to_ssaid_map[refid] else # It's a get that we legitimately still need. - ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) - push!(new_insts, (id, ninst)) + ninst = Expr( + :call, + Libtask.get_ref_at, + replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[2]), + refid_to_new_refid_map[refid], + ) + push!(new_insts, (id, new_inst(ninst))) end else # Some other call instruction. ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) push!(new_insts, (id, ninst)) end + elseif inst.stmt isa IDPhiNode + # Replace any SSA IDs in the phi node. + ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) + # then replace any TupleRefs with the new ref id + new_values = Vector{Any}(undef, length(ninst.stmt.values)) + for n in eachindex(ninst.stmt.values) + if isassigned(ninst.stmt.values, n) + val = ninst.stmt.values[n] + new_values[n] = if val isa Libtask.TupleRef + if !haskey(refid_to_new_refid_map, val.n) + # This should never happen, because if `val.n` was in the + # phi node, it always counts as an upwards-exposed use of + # that ref, and should therefore always be included in + # `necessary_ref_ids`. + error("found TupleRef with unused ref id $(val.n)") + end + TupleRef(refid_to_new_refid_map[val.n]) + else + val + end + end + end + push!(new_insts, (id, new_inst(IDPhiNode(ninst.stmt.edges, new_values)))) else - # Some other (non-call) instruction. + # Some other (non-call, non-PhiNode) instruction. ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst) push!(new_insts, (id, ninst)) end @@ -296,13 +316,18 @@ function eliminate_refs(ir::BBCode, refs::Tuple) return BBlock(block.id, new_insts) end + # The tuple of refs is passed in as the first argument to the IR, so we need to update + # the types. + new_argtypes = vcat(typeof(new_refs), copy(ir.argtypes[2:end])) + new_ir = @static if VERSION >= v"1.12-" - BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds) + BBCode( + new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds + ) else - BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.linetable, ir.meta) + BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) end - # return ir, refs - return new_ir, refs + return new_ir, new_refs end # Return a vector of block IDs in reverse postorder on the reverse CFG (i.e., the CFG where diff --git a/src/test_utils.jl b/src/test_utils.jl index 98e3abc0..9285cfc5 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -226,7 +226,7 @@ function test_cases() "default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), [], allocs ), Testcase( - "final statment produce", + "final statement produce", nothing, (final_statement_produce,), nothing,