Fix #791: Add threaded_map! with reverse-mode AD support#1104
Conversation
Introduces Mooncake.threaded_map!(f, output, input) as a new exported primitive for element-wise in-place mapping over IEEEFloat vectors using Threads.@threads. An rrule is provided via build_primitive_rrule, which pre-compiles the scalar rule for f at rule-construction time and reuses it across all elements and calls. The reverse pass (pullback) runs in parallel with the same @threads loop, accumulating cotangents into the input gradient vector. Constraints: T1,T2 <: IEEEFloat (element-level independence) and f must carry no mutable differentiable state (rule is shared across threads). Closes #791. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
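A minimal, Mooncake-free sketch of the forward-pass shape described above (not the PR's actual code; `threaded_map_sketch!` is an illustrative name):

```julia
# Hedged sketch of an in-place threaded element-wise map. Element-level
# independence is what makes the parallel loop race-free: index i is written
# by exactly one thread.
function threaded_map_sketch!(f, output::Vector, input::Vector)
    length(output) == length(input) || throw(DimensionMismatch("length mismatch"))
    Threads.@threads for i in eachindex(output)
        output[i] = f(input[i])   # one writer per index: race-free by construction
    end
    return output
end
```

The actual primitive additionally constrains element types and supplies an AD rule; this sketch only shows why the primal loop is safe to parallelise.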
make_ad_stmts! always inserts rrule!! for primitive calls in nested IR —
build_primitive_rrule is only used for top-level build_rrule calls.
Replace the ThreadedMapRRule/build_primitive_rrule approach with a direct
rrule!! method. Add an IdDict cache so the inner build_rrule(Tuple{F,T2})
is paid once per (F,T2) type combination.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
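The per-(F, T2) cache described in this commit can be sketched as follows (illustrative names only; `_RULE_CACHE`, `cached_rule`, and `build_rule_for` are not Mooncake internals):

```julia
# Hedged sketch: memoise an expensive per-type-combination construction in an
# IdDict so it is paid once per (F, T2) pair. Tuples of types compare egal,
# so repeated lookups with the same pair hit the cache.
const _RULE_CACHE = IdDict{Tuple{DataType,DataType},Any}()

function cached_rule(build_rule_for, F::DataType, T2::DataType)
    return get!(() -> build_rule_for(Tuple{F,T2}), _RULE_CACHE, (F, T2))
end
```

A later commit in this PR drops the manual cache in favour of Julia's method dispatch, which specialises and caches per signature automatically.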
- Accept any number of input vectors: threaded_map!(f, output, x1, x2, ...)
- Update @is_primitive to Vararg{Vector{T}} signature
- Update rrule!! to handle variadic inputs/cotangents; cache key is now (F, N, T);
inner scalar rule built as build_rrule(Tuple{F, Vararg{T,N}})
- Pullback loops over j in 1:N to accumulate cotangents for all inputs
- Returns ntuple(_ -> NoRData(), Val(N))... for the variadic rdatas
- Remove all @inbounds annotations
- Add 2-input (+) test case in hand_written_rule_test_cases
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When f is a closure capturing IEEEFloat scalars, fdata_type(F)=NoFData so the forward pass is already race-free. The reverse pass now stores per-element rdata for f in f_rdatas[i] (each thread writes only its own slot — race-free), then folds the vector serially via increment_internal!! after the parallel loop. Returns f_rdata at position 2 of the pullback tuple instead of NoRData(). For plain functions (rdata_type(tangent_type(F))==NoRData) the fast path skips all allocation and returns NoRData() as before. Includes a TODO for Option B (per-thread accumulators with :static scheduling) and adds an isbits-closure test case to hand_written_rule_test_cases. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
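The per-element-slot-then-serial-fold scheme above can be sketched in plain Julia (illustrative names; `Float64` stands in for the closure's rdata type and `+` for `increment_internal!!`):

```julia
# Hedged sketch: each thread writes only its own slot f_rdatas[i] (race-free),
# then a serial fold after the parallel loop combines the per-element rdatas.
function accumulate_f_rdata(per_element_rdata, n::Int)
    f_rdatas = Vector{Float64}(undef, n)
    Threads.@threads for i in 1:n
        f_rdatas[i] = per_element_rdata(i)  # one writer per slot
    end
    acc = 0.0
    for r in f_rdatas                       # serial fold, stand-in for increment_internal!!
        acc += r
    end
    return acc
end
```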
- Run element 1 serially to determine the concrete pullback type (PBType =
typeof(pb1)), then allocate Vector{PBType} instead of Vector{Any}. Avoids
boxing and dynamic dispatch on every pullback call in the reverse pass.
- Add @inbounds to all hot array accesses in rrule!! and the pullback (bounds
are guaranteed by the length checks on entry).
- Handle n=0 (empty input vectors) with an early-return no-op pullback.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
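The "run element 1 serially to pin down the concrete type" trick from this commit looks roughly like this (hedged sketch, illustrative names, assumes n ≥ 1):

```julia
# Run the first element serially so typeof(pb1) is known, then allocate a
# concretely-typed vector instead of Vector{Any} — avoiding boxing and dynamic
# dispatch when the pullbacks are later invoked.
function collect_pullbacks(make_pb, n::Int)
    pb1 = make_pb(1)                          # serial warm-up call
    pullbacks = Vector{typeof(pb1)}(undef, n) # concrete eltype
    pullbacks[1] = pb1
    Threads.@threads for i in 2:n
        pullbacks[i] = make_pb(i)
    end
    return pullbacks
end
```

This relies on every element producing a pullback of the same type, which holds when the same scalar rule is applied at each index.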
Previously all vectors (output + inputs) had to share the same element type T.
Now output::Vector{Tout} and each input::Vector{Ti} can have independent element
types as long as Tout, Ti <: IEEEFloat. Race-freedom is unaffected: element-wise
cotangent writes are still independent by index.
- threaded_map! signature: Vector{<:IEEEFloat} for output and each input vararg
- @is_primitive: Vararg{Vector{<:IEEEFloat}} to match any IEEEFloat combination
- rrule!!: output::CoDual{Vector{Tout},...}, inputs::CoDual{<:Vector{<:IEEEFloat},...}...
- Scalar rule cache key changed from (F, N, T) to (F, input_element_types_tuple);
rule built as build_rrule(Tuple{F, T1, T2, ..., TN})
- New test cases: Float32→Float64 single input, Float32+Float64→Float64 two inputs
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- threaded_map!: remove IEEEFloat constraint, add isbitstype() runtime check
- @is_primitive: match Vector without type parameter
- rrule!!: call rrule!! directly (Julia dispatch is the cache); Vector{Any}
for pullbacks; drop n==0 special case (empty loops handle it naturally)
- Remove all @inbounds annotations
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace Vector{Any} with Vector{PBType} where PBType is inferred via
pullback_type (which uses Core.Compiler.return_type, same technique as
Base.Broadcast.combine_eltypes). Eliminates boxing and dynamic dispatch
on the per-element pullback calls in the reverse pass.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
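The inference technique named above can be illustrated directly (`pullback_type_sketch` is a hypothetical stand-in for the PR's `pullback_type` helper):

```julia
# Core.Compiler.return_type — the same technique Base.Broadcast.combine_eltypes
# uses internally — predicts a concrete return type, letting callers allocate
# Vector{PB} instead of Vector{Any} and avoid boxing.
pullback_type_sketch(f, argtypes::Type{<:Tuple}) = Core.Compiler.return_type(f, argtypes)
```

As with broadcast, if inference gives an abstract type the code still works; it just falls back to a less efficient, abstractly-typed container.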
Val(N) inside a Threads.@threads loop body is computed with N captured as Int64, so ntuple sees Val{::Int64} and falls back to a generic loop instead of unrolling. Hoisting val_N = Val(N) to the outer scope (where N is a compile-time constant) lets ntuple specialize on Val{N}, giving the expected ~5x primal speedup on n=100k with 12 threads. Applied to both threaded_map! (primal) and rrule!! forward/pullback. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
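The hoisting can be illustrated in isolation (hedged sketch; `row_sums_hoisted!` is not the PR's code):

```julia
# Hoist Val(N) out of the threaded loop body: here N is a type-level constant,
# so ntuple specializes on Val{N} and unrolls instead of falling back to a
# generic loop over a runtime Int.
function row_sums_hoisted!(out::Vector{Float64}, xs::NTuple{N,Vector{Float64}}) where {N}
    val_N = Val(N)   # hoisted: computed where N is known to the compiler
    Threads.@threads for i in eachindex(out)
        out[i] = sum(ntuple(j -> xs[j][i], val_N))
    end
    return out
end
```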
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both loss_serial_ad and loss_threaded now use sum(map!-variant(f, zeros(n), v)), making the AD speedup comparison structurally identical. Primal correctness check also uses map! directly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
sorry Niko. I don't think I understand Claude's explanation — what is the issue, what is the implementation, and why should I be convinced this is correct?
I must misunderstand the question, because the obvious answers are: a) #791, b) https://github.com/chalk-lab/Mooncake.jl/blob/6a1c58f5b5b33d10dcb843e4a4c428228de401ea/src/rules/threads.jl, c) AFAICT, it's not correct yet and also a bit wonky - but the main logic should be correct and also quite self contained. But I think you may have been after a different answer?
I think I roughly understand the motivation now; I was a little confused by Claude's explanation. The method makes sense to me. But it does feel like there are more things to consider, for instance what kind of
```julia
FRData !== NoRData && (f_rdatas[i] = rdata_tuple[1])
for j in 1:N
    xd = xds[j]
    xd isa NoFData || (xd[i] += rdata_tuple[j + 1])
```

is `+` defined for all bit type tangents?
I accidentally closed this PR.
- Extract _check_threaded_map! helper (validation + length calc) called from both threaded_map! primal and rrule!!. Previously rrule!! skipped validation.
- Drop old_yd copy: the forward pass sets yd[i]=0, so the pullback can restore by zeroing after reading instead of copying and restoring a saved vector.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
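The "restore by zeroing after reading" trick can be sketched in plain Julia (`drain_cotangents!` is an illustrative name; `Float64` stands in for the tangent type):

```julia
# The forward pass left every yd[i] at zero, so the pullback can take the
# accumulated cotangent and reset the slot in place — no saved copy needed.
function drain_cotangents!(yd::Vector{Float64})
    out = similar(yd)
    for i in eachindex(yd)
        out[i] = yd[i]   # read the accumulated cotangent
        yd[i] = 0.0      # restore the forward-pass (all-zero) state
    end
    return out
end
```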
For isbits structs, tangent_type(Ti) = Tangent{...} but the inner
pullback returns RData{...} — different types, += fails. For Float64,
both happen to be Float64 so += works, masking the bug.
increment_rdata!!(t, r) = tangent(fdata(t), increment(rdata(t), r))
handles all cases correctly: Float64 (reduces to +), isbits structs
(increments the RData component of the Tangent), NoTangent (no-op).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove issue URL from docstring (line 12)
- Remove unused `where {F}` from threaded_map! (F not referenced in body)
- Add check for f in _check_threaded_map!: fdata_type(tangent_type(typeof(f)))==NoFData
- Widen rrule!! output/input types from CoDual{Vector{Tout},Vector{Tout}} to
CoDual{<:Vector}: the old annotation required tangent_type(Tout)==Tout, which
only holds for IEEEFloat eltypes. Tout was also unused in the body.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
I think this could now be ready for review, @sunxd3 - do you want to have a look? You already did a bit, didn't you? Personally, I'm a bit confused by
```julia
pullbacks = Vector{PBType}(undef, n)
Threads.@threads for i in 1:n
    xi = ntuple(j -> CoDual(xps[j][i], NoFData()), valN)
    yi_codual, pb = rrule!!(zero_fcodual(fp), xi...)
```

For non-primitives, a rule needs to be built first.
```julia
using Mooncake

function loss_anon(v)
    out = similar(v)
    threaded_map!(x -> 2x, out, v)
    return sum(out)
end

x = randn(4)

# Both fail: threaded_map!'s rrule!! calls rrule!! on f, which has no hand-written rule.
rule = Mooncake.build_rrule(loss_anon, x)
value_and_gradient!!(rule, loss_anon, x)
# => MethodError at src/rules/threads.jl:76
```

```julia
# but the pullback returns RData{...} — different types; increment_rdata!! handles
# all tangent types uniformly.
Threads.@threads for i in 1:n
    dy_i = yd[i]
```
The pullback reads dy_i = yd[i] and passes it directly to the inner pullback:

```julia
dy_i = yd[i]
# ...
rdata_tuple = pullbacks[i](dy_i)
```

yd[i] is a tangent (the element of the tangent vector). Inner pullbacks expect rdata, not the full tangent. For Float64, tangent and rdata are the same type, so this works by accident. For isbits structs, they diverge:

```julia
struct Pair; a::Float64; b::Float64; end
tangent_type(Pair)              # => Tangent{@NamedTuple{a::Float64, b::Float64}}
rdata_type(tangent_type(Pair))  # => RData{@NamedTuple{a::Float64, b::Float64}}
```

The canonical pattern in isbits_arrayset_rrule (array_legacy.jl:463) correctly extracts rdata before returning:

```julia
dv = rdata(arrayref(false, dA, lin_inds))
```

The PR should do the same: pullbacks[i](rdata(dy_i)) instead of pullbacks[i](dy_i).
MWE:

```julia
using Mooncake
using Mooncake: CoDual, zero_codual, zero_fcodual, rrule!!, NoRData,
    tangent_type, rdata_type, @is_primitive, MinimalCtx, ReverseMode

struct Pair; a::Float64; b::Float64; end

T_rdata = rdata_type(tangent_type(Pair))

# A primitive returning Pair, whose pullback correctly expects RData:
pair_of(x::Float64) = Pair(x, 2x)
@is_primitive MinimalCtx ReverseMode Tuple{typeof(pair_of), Float64}
function Mooncake.rrule!!(::CoDual{typeof(pair_of)}, x::CoDual{Float64})
    px = Mooncake.primal(x)
    function pb(dy)
        dy isa T_rdata || error("pullback got $(typeof(dy)), expected $T_rdata")
        return NoRData(), dy.data.a + 2 * dy.data.b
    end
    return zero_fcodual(Pair(px, 2px)), pb
end

# Forward pass through threaded_map! works:
result, tmb = rrule!!(
    zero_fcodual(threaded_map!),
    zero_fcodual(pair_of),
    zero_codual(Vector{Pair}(undef, 3)),
    zero_codual([1.0, 2.0, 3.0]),
)

# Simulate upstream cotangent accumulation into yd (as sum/getfield pullbacks would):
yd = Mooncake.tangent(result)
for i in 1:3
    yd[i] = Mooncake.increment_rdata!!(yd[i], T_rdata((a=1.0, b=1.0)))
end

# yd[1] is Tangent{...}, but the inner pullback expects RData{...}:
typeof(yd[1])  # => Mooncake.Tangent{@NamedTuple{a::Float64, b::Float64}}
T_rdata        # => Mooncake.RData{@NamedTuple{a::Float64, b::Float64}}

# Pullback fails: threaded_map! passes yd[i] :: Tangent where RData is expected:
tmb(NoRData())
# => Error: pullback got Tangent{...}, expected RData{...}
```

```julia
    return n
end

@is_primitive MinimalCtx ReverseMode Tuple{
```
It's okay to close this PR if there's consensus; doing so robustly and usefully is hard. |
Please do this in a separate repo. |
This looks good enough now to me. There are obvious extensions using similar patterns, e.g. an allocating threaded_map, but also something like threaded_materialize, which for broadcasted elements could copy things around as this version currently does for closures - though a more efficient implementation would of course also be good.
Human content above
Claude content below
Problem

Mooncake has no way to differentiate a multi-threaded element-wise map. Any function containing a Threads.@threads loop over an array currently either errors or silently produces incorrect gradients because the interpreter cannot handle the threading primitives. The issue proposes a threaded_map! utility whose interface is narrow enough to guarantee that both the forward pass and the reverse pass are race-free.

Change

- src/rules/threads.jl — new file (~120 lines).
- _check_threaded_map!(output, inputs) helper that validates isbitstype elements, non-empty inputs, and output length; returns the effective n. Called from both threaded_map! and rrule!! so validation is shared.
- Mooncake.threaded_map!(f, output::Vector, inputs::Vector...) that calls _check_threaded_map! then runs output[i] = f(inputs[1][i], inputs[2][i], ...) in a Threads.@threads loop. Element types may differ freely across vectors.
- Registered as a MinimalCtx ReverseMode primitive via @is_primitive Tuple{typeof(threaded_map!), F, Vector, Vararg{Vector}} where {F}.
- A direct rrule!! method. make_ad_stmts! always inserts rrule!! for primitives in nested IR (the build_primitive_rrule hook only fires for top-level build_rrule calls), so rrule!! must be the entry point. Julia's method dispatch naturally caches specialisations — no manual IdDict cache is needed.
- rrule!! forward pass: calls _check_threaded_map! for validation; infers the concrete pullback element type via pullback_type(typeof(rrule!!), (F, map(eltype, xps)...)) (uses Core.Compiler.return_type, same technique as Base.Broadcast.combine_eltypes), allocates Vector{PBType}(undef, n), saves old_yp = copy(yp) (primal only — no need to save yd), then runs all elements in parallel via Threads.@threads. The n == 0 case is handled naturally by the empty loop — no special branch.
- rrule!! pullback: reads dy_i = yd[i], immediately zeros yd[i] = zero_tangent(yp[i]) (the forward pass set it to zero, so zeroing restores it without needing a saved copy), then calls pullbacks[i](dy_i) and accumulates input cotangents via xd[i] = increment_rdata!!(xd[i], rdata_tuple[j+1]) (element-local writes, race-free). Using increment_rdata!! rather than += is necessary for correctness: for isbits structs, xd[i] :: Tangent{...} but the pullback returns RData{...} — different types; increment_rdata!! handles all tangent types uniformly. Non-differentiable inputs (xd isa NoFData) are skipped.
- Closure support (rdata_type(tangent_type(F)) ≠ NoRData): stores per-element f-rdata in f_rdatas[i] (one slot per element, race-free), folds serially via increment_internal!! to produce f_rdata. Plain functions (rdata_type == NoRData) skip this entirely. Returns f_rdata at position 2 of the rdata tuple.
- Primal output is restored via copyto!(yp, old_yp).
- A TODO records Option B (Threads.@threads :static + threadid(), O(nthreads) memory) as a future improvement.
- Tests: hand_written_rule_test_cases(rng_ctor, ::Val{:threads}) with Float64 (1-input), Float32 (1-input), Float64 (2-input+), isbits closure, and two heterogeneous-type cases (Float32→Float64, Float32+Float64→Float64).
- src/Mooncake.jl — two lines: include(joinpath("rules", "threads.jl")) after the high_order_derivative_patches.jl include; export threaded_map!.
- test/rules/threads.jl — new file: TestUtils.run_rule_test_cases(StableRNG, Val(:threads)).
- test/runtests.jl — one elseif branch added: test_group == "rules/threads".
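The forward/pullback structure above can be sketched without any Mooncake machinery (hedged, greatly simplified; single Float64 input, `scalar_rule(x)` stands in for the inner per-element rule and must return `(y, pullback)`; all names are illustrative and it assumes n ≥ 1):

```julia
# Forward: run element 1 serially to pin down the concrete pullback type, then
# fill y and pullbacks in parallel. Pullback: element-local writes into dx are
# race-free because index i is touched by exactly one thread.
function threaded_map_rrule_sketch(scalar_rule, x::Vector{Float64})
    n = length(x)
    y = Vector{Float64}(undef, n)
    y[1], pb1 = scalar_rule(x[1])               # serial first call
    pullbacks = Vector{typeof(pb1)}(undef, n)   # concrete eltype, no boxing
    pullbacks[1] = pb1
    Threads.@threads for i in 2:n
        y[i], pullbacks[i] = scalar_rule(x[i])
    end
    function pullback(dy::Vector{Float64})
        dx = zeros(n)
        Threads.@threads for i in 1:n
            dx[i] += pullbacks[i](dy[i])        # element-local: race-free
        end
        return dx
    end
    return y, pullback
end
```

The real rule additionally handles fdata/rdata bookkeeping, closure rdata, and primal restoration; this sketch isolates the threading and type-stability structure.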
Correctness argument

Race-freedom relies on two constraints enforced at runtime and by the primitive's type signature:

1. isbitstype elements — every element read/write is a scalar. Thread i reads/writes only xds[j][i] and yd[i], so there are no shared-memory races regardless of whether the element types differ across vectors. The runtime check ensures this property holds for the actual vectors passed.
2. fdata_type(tangent_type(F)) == NoFData — zero_fcodual(fp) produces CoDual{F, NoFData}, which can be shared across threads with no race on fdata. This covers plain functions and closures capturing only isbits scalars. The per-element rdata for f (from closures capturing isbits values) is accumulated race-free via per-element storage and a serial fold.
MWE

mwe.jl runs on a 100k-element Float64 vector using tanh (higher compute-to-memory ratio than sin) and checks three things:

1. threaded_map!(tanh, out, x) matches a serial reference.
2. value_and_gradient!! on sum(threaded_map!(tanh, zeros(n), v)) gives 1 .- tanh.(x).^2 within tolerance.
3. With JULIA_NUM_THREADS > 1, value_and_gradient!! from a prepared gradient cache using sum(threaded_map!(f, zeros(n), v)) must be ≥ 1.3× faster than the structurally identical sum(map!(f, zeros(n), v)) differentiated by Mooncake's interpreter. Both benchmarks time only value_and_gradient!! (not prepare_gradient_cache). Primal timing is reported informational-only (not asserted) because map! can be SIMD-vectorized, making the primal speedup ratio unreliable.

On main: exits 1 (threaded_map! undefined). On the worktree: exits 0 when all three checks pass.
Fixes #791