Mooncake currently traces through `Base.Broadcast.materialize`, `map`, `mapreduce`, and `reduce` on CPU — the AD interpreter walks Julia's implementation IR and differentiates each inner operation. This is correct and general, but can sometimes be (slightly) inefficient in reverse mode when the pullback blocks further compiler optimisation (e.g., #249).
For GPU arrays, Mooncake already takes a different approach (#1056): it intercepts `materialize` as a primitive in `DefaultCtx`, runs a single fused kernel using `NDual{T,N}` (an N-wide dual number), and computes primals + all partial derivatives in one pass. The GPU rule cannot trace through because GPU kernels are opaque to Mooncake.
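To make the N-wide dual idea concrete, here is a minimal self-contained sketch of the concept, not Mooncake's actual `NDual` implementation: one primal plus N tangent slots, so seeding slot i with a one in input i yields the primal and all N partials in a single fused pass.

```julia
# Minimal sketch of an N-wide dual number (illustrative, not Mooncake's NDual).
struct NDual{T,N}
    primal::T
    partials::NTuple{N,T}
end

Base.:*(a::NDual{T,N}, b::NDual{T,N}) where {T,N} =
    NDual(a.primal * b.primal,
          ntuple(i -> a.partials[i] * b.primal + a.primal * b.partials[i], N))
Base.:+(a::NDual{T,N}, b::NDual{T,N}) where {T,N} =
    NDual(a.primal + b.primal, ntuple(i -> a.partials[i] + b.partials[i], N))

# Seed input i with a one in tangent slot i:
x = NDual(2.0, (1.0, 0.0))
y = NDual(3.0, (0.0, 1.0))
z = x * y + x   # primal 8.0, partials (4.0, 2.0) = (∂z/∂x, ∂z/∂y)
```

Broadcasting an elementwise function over arrays of such duals is what lets the rule compute `f.(xs...)` and the full elementwise Jacobian in one kernel, rather than tracing each inner operation.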
Proposal
Apply the same interception pattern to CPU for higher-order functions with elementwise structure — `broadcast`, `map`, and the `f`-part of `mapreduce` — while retaining trace-through as the fallback for cases where NDual is not applicable (non-isbits element types, complex reduction operators).
The key mechanism
Mark the function as a primitive. Inside the `rrule!!`, dispatch to NDual if the inputs are amenable; otherwise explicitly invoke the derived (trace-through) rule:
```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle, materialize

@is_primitive DefaultCtx Tuple{
    typeof(Base.Broadcast.materialize),
    Broadcasted{<:Base.Broadcast.AbstractArrayStyle},  # CPU array styles
}

function rrule!!(
    f_cd::CoDual{typeof(Base.Broadcast.materialize)},
    bc_cd::CoDual{<:Broadcasted{<:AbstractArrayStyle}},
)
    bc = primal(bc_cd)
    if _can_use_ndual(bc)
        # Fast path: single NDual pass, O(N) where N = number of distinct array inputs.
        return _ndual_materialize_rrule(f_cd, bc_cd)
    else
        # Pass-through: build the derived rule in MinimalCtx, where materialize is
        # NOT a primitive, recovering the current trace-through behaviour.
        derived = build_rrule(
            MooncakeInterpreter(MinimalCtx, ReverseMode),
            Tuple{typeof(materialize),typeof(bc)},
        )
        return derived(f_cd, bc_cd)
    end
end

# NDual is viable when all array args hold isbits floats and N is small.
_can_use_ndual(bc::Broadcasted) =
    all(a -> a isa AbstractArray{<:IEEEFloat} || a isa IEEEFloat, bc.args) &&
    _ndual_width(bc) ≤ 8  # tunable threshold
```
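The sketch above leaves `_ndual_width` undefined. One plausible definition (an assumption, not taken from the source) counts the differentiable leaf arguments, since each needs its own tangent slot:

```julia
using Base.Broadcast: Broadcasted

# Hypothetical helper, not from the source: one tangent slot per differentiable
# leaf argument. Assumes `bc` has already been flattened, so args are leaves.
_ndual_width(bc::Broadcasted) =
    count(a -> a isa AbstractArray{<:AbstractFloat} || a isa AbstractFloat, bc.args)
```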
The same pattern applies to `map` (identical diagonal structure to broadcast) and to `mapreduce` for `op ∈ {+, add_sum}` (NDual handles the `f` part; the reduction Jacobian is trivially all-ones).
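The all-ones reduction Jacobian means the pullback for the `+` stage simply broadcasts the scalar output cotangent to every `f` output. A minimal sketch of such a rule for `mapreduce(f, +, x)`, using ForwardDiff's scalar dual as a stand-in for a 1-wide NDual forward pass:

```julia
using ForwardDiff  # scalar duals as a stand-in for a 1-wide NDual

# Illustrative reverse rule for mapreduce(f, +, x): one forward pass yields
# primals and elementwise derivatives; since ∂(sum)/∂(f(x_i)) = 1 for every i,
# the pullback just scales the derivatives by the output cotangent.
function sum_mapreduce_rule(f, x::AbstractVector{Float64})
    primals = map(f, x)
    derivs  = map(xi -> ForwardDiff.derivative(f, xi), x)
    y = sum(primals)
    pullback(ȳ) = ȳ .* derivs
    return y, pullback
end

y, pb = sum_mapreduce_rule(sin, [0.0, 1.0])
pb(1.0)  # == cos.([0.0, 1.0])
```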
Benefits
- One interception point per function family, regardless of `.`-fusion depth (since `Base.Broadcast.flatten` already fuses nested broadcasts into one composed function before `materialize` is called)
- Fully backwards-compatible: the pass-through path is semantically identical to the current behaviour
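The fusion behaviour behind the first benefit can be checked directly: `Base.Broadcast.flatten` collapses a nested lazy `Broadcasted` tree into a single node with one composed function and a flat argument tuple (a small REPL sketch, not Mooncake-specific):

```julia
using Base.Broadcast: broadcasted, flatten, materialize

# x .* 3 .+ 1 lowers to a nested lazy Broadcasted tree:
bc   = broadcasted(+, broadcasted(*, [1.0, 2.0], 3.0), 1.0)
flat = flatten(bc)   # one Broadcasted node with a composed function
flat.args            # ([1.0, 2.0], 3.0, 1.0) — all leaf inputs, no nesting
materialize(flat)    # [4.0, 7.0], same result as materialize(bc)
```

This is why a single interception at `materialize` sees the whole fused expression rather than one call per dotted operator.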
Open questions
- What is the right threshold for `_can_use_ndual` (i.e., the maximum N before NDual is slower than the tape)?