Higher-order functions or broadcasting as optional primitives #1059

@yebai

Description

Mooncake currently traces through Base.Broadcast.materialize, map, mapreduce, and reduce on CPU — the AD interpreter walks Julia's implementation IR and differentiates each inner operation. This is correct and general, but can sometimes be (slightly) inefficient in reverse mode when the pullback blocks further compiler optimisation (e.g., #249).

For GPU arrays, Mooncake already takes a different approach (#1056): it intercepts materialize as a primitive in DefaultCtx, runs a single fused kernel using NDual{T,N} (an N-wide dual number), and computes primals + all partial derivatives in one pass. The GPU rule cannot trace through because GPU kernels are opaque to Mooncake.
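To make the NDual{T,N} idea concrete, here is a hypothetical minimal sketch (the name `NDualSketch` and the two overloads are illustrative, not Mooncake's actual implementation): an N-wide dual number carries one primal and N partials, so a single fused pass yields the primal and the derivative with respect to each of N inputs.

```julia
# Hypothetical sketch of an N-wide dual number in the spirit of NDual{T,N}.
struct NDualSketch{T,N}
    primal::T
    partials::NTuple{N,T}
end

# Sum rule: partials add elementwise.
Base.:+(a::NDualSketch{T,N}, b::NDualSketch{T,N}) where {T,N} =
    NDualSketch(a.primal + b.primal, a.partials .+ b.partials)

# Product rule: d(ab) = a db + b da.
Base.:*(a::NDualSketch{T,N}, b::NDualSketch{T,N}) where {T,N} =
    NDualSketch(a.primal * b.primal,
                a.primal .* b.partials .+ b.primal .* a.partials)

# Differentiate f(x, y) = x * y + x at (2.0, 3.0) w.r.t. both inputs at once:
x = NDualSketch(2.0, (1.0, 0.0))  # seed ∂/∂x
y = NDualSketch(3.0, (0.0, 1.0))  # seed ∂/∂y
z = x * y + x
# z.primal == 8.0, z.partials == (4.0, 2.0)
```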

Proposal

Apply the same interception pattern to CPU for higher-order functions with elementwise structure — broadcast, map, and the f-part of mapreduce — while retaining trace-through as the fallback for cases where NDual is not applicable (non-isbits element types, complex reduction operators).

The key mechanism

Mark the function as a primitive. Inside the rrule!!, dispatch to NDual if the inputs are amenable; otherwise explicitly invoke the derived (trace-through) rule:

@is_primitive DefaultCtx Tuple{
    typeof(Base.Broadcast.materialize),
    Broadcasted{<:Base.Broadcast.AbstractArrayStyle}   # CPU array styles
}

function rrule!!(
    f_cd::CoDual{typeof(Base.Broadcast.materialize)},
    bc_cd::CoDual{<:Broadcasted{<:AbstractArrayStyle}}
)
    bc = primal(bc_cd)

    if _can_use_ndual(bc)
        # Fast path: single NDual pass, O(N) where N = number of distinct array inputs
        return _ndual_materialize_rrule(f_cd, bc_cd)
    else
        # Pass-through: build derived rule in MinimalCtx (materialize is NOT primitive there)
        derived = build_rrule(
            MooncakeInterpreter(MinimalCtx, ReverseMode),
            Tuple{typeof(materialize), typeof(bc)}
        )
        return derived(f_cd, bc_cd)
    end
end

# NDual is viable when all array args are isbits floats, and N is small
_can_use_ndual(bc::Broadcasted) =
    all(a -> a isa AbstractArray{<:IEEEFloat} || a isa IEEEFloat, bc.args) &&
    _ndual_width(bc) <= 8   # tunable threshold

The same pattern applies to map (identical diagonal structure to broadcast) and to mapreduce for op ∈ {+, add_sum} (NDual handles the f part; the reduction Jacobian is trivially all-ones).
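As a sanity check on the mapreduce case, here is a hedged standalone sketch (the helper `sum_map_pullback` and the explicit `df` argument are illustrative; Mooncake would obtain f' via NDual in the same fused pass): because the Jacobian of sum is all ones, the whole pullback collapses to scaling each elementwise derivative by the incoming cotangent.

```julia
# Sketch: pullback of y = mapreduce(f, +, x) given the cotangent ȳ.
# The all-ones reduction Jacobian means dx[i] = ȳ * f'(x[i]).
function sum_map_pullback(f, df, x::AbstractArray, ȳ)
    y  = sum(f, x)       # primal
    dx = ȳ .* df.(x)     # each f'(x[i]) scaled by ȳ; no per-element tape needed
    return y, dx
end

y, dx = sum_map_pullback(sin, cos, [0.0, 1.0], 1.0)
# y == sin(0.0) + sin(1.0); dx == [cos(0.0), cos(1.0)]
```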

Benefits

  • One interception point per function family, regardless of .-fusion depth (since Base.Broadcast.flatten already fuses nested broadcasts into one composed function before materialize is called)
  • Fully backwards-compatible: the pass-through path is semantically identical to the current behaviour
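A quick illustration of the first point, using only real Base.Broadcast API: nested dot-calls build a lazy Broadcasted tree, and flatten composes it into a single function with a flat argument list, so intercepting materialize once covers any fusion depth.

```julia
using Base.Broadcast: broadcasted, flatten, materialize

x = [1.0, 2.0]
y = [3.0, 4.0]

bc  = broadcasted(+, broadcasted(*, x, y), x)  # lazy tree for x .* y .+ x
fbc = flatten(bc)                              # one composed function, flat args

materialize(fbc)  # same result as x .* y .+ x, i.e. [4.0, 10.0]
```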

Open questions

  1. What is the right threshold for _can_use_ndual (i.e., max N before NDual is slower than the tape)?
