Fix #791: Add threaded_map! with reverse-mode AD support #1104

Closed

nsiccha wants to merge 15 commits into main from issue-791-threaded-map
Conversation

@nsiccha
Collaborator

@nsiccha nsiccha commented Mar 24, 2026

This looks good enough to me now. There are obvious extensions using similar patterns, e.g. an allocating threaded_map, but also something like threaded_materialize, which for broadcasted elements could copy things around as this version currently does for closures - though a more efficient implementation would of course also be good.

Human content above


Claude content below

Problem

Mooncake has no way to differentiate a multi-threaded element-wise map. Any function containing a Threads.@threads loop over an array currently either errors or silently produces incorrect gradients because the interpreter cannot handle the threading primitives.

The issue proposes a threaded_map! utility whose interface is narrow enough to guarantee that both the forward pass and the reverse pass are race-free.
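As a rough sketch of the proposed interface (pure Julia, independent of Mooncake's internals; threaded_map_sketch! is a hypothetical stand-in for the PR's threaded_map!):

```julia
# Hypothetical stand-in for the proposed Mooncake.threaded_map!; shows only the
# primal semantics: output[i] = f(inputs[1][i], ..., inputs[N][i]) across threads.
function threaded_map_sketch!(f, output::Vector, inputs::Vector...)
    n = length(output)
    Threads.@threads for i in 1:n
        output[i] = f(map(x -> x[i], inputs)...)  # element-local write: race-free
    end
    return output
end

out = zeros(4)
threaded_map_sketch!(+, out, [1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0])
# out == [11.0, 22.0, 33.0, 44.0]
```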

Change

src/rules/threads.jl — new file (~120 lines).

  • Defines a _check_threaded_map!(output, inputs) helper that validates isbitstype elements, non-empty inputs, and output length; returns the effective n. Called from both threaded_map! and rrule!! so validation is shared.
  • Defines Mooncake.threaded_map!(f, output::Vector, inputs::Vector...) that calls _check_threaded_map! then runs output[i] = f(inputs[1][i], inputs[2][i], ...) in a Threads.@threads loop. Element types may differ freely across vectors.
  • Declares it as a MinimalCtx ReverseMode primitive via @is_primitive Tuple{typeof(threaded_map!), F, Vector, Vararg{Vector}} where {F}.
  • Defines rrule!! directly. make_ad_stmts! always inserts rrule!! for primitives in nested IR (the build_primitive_rrule hook only fires for top-level build_rrule calls), so rrule!! must be the entry point. Julia's method dispatch naturally caches specialisations — no manual IdDict cache is needed.
  • rrule!! forward pass: calls _check_threaded_map! for validation; infers the concrete pullback element type via pullback_type(typeof(rrule!!), (F, map(eltype, xps)...)) (uses Core.Compiler.return_type, same technique as Base.Broadcast.combine_eltypes), allocates Vector{PBType}(undef, n), saves old_yp = copy(yp) (primal only — no need to save yd), then runs all elements in parallel via Threads.@threads. The n == 0 case is handled naturally by the empty loop — no special branch.
  • rrule!! pullback:
    • Parallel reverse pass: reads dy_i = yd[i], immediately zeros yd[i] = zero_tangent(yp[i]) (the forward pass set it to zero, so zeroing restores it without needing a saved copy), then calls pullbacks[i](dy_i) and accumulates input cotangents via xd[i] = increment_rdata!!(xd[i], rdata_tuple[j+1]) (element-local writes, race-free). Using increment_rdata!! rather than += is necessary for correctness: for isbits structs, xd[i] :: Tangent{...} but the pullback returns RData{...} — different types; increment_rdata!! handles all tangent types uniformly. Non-differentiable inputs (xd isa NoFData) are skipped.
    • For closures capturing isbits scalars (rdata_type(tangent_type(F)) ≠ NoRData): stores per-element f-rdata in f_rdatas[i] (one slot per element, race-free), folds serially via increment_internal!! to produce f_rdata. Plain functions (rdata_type == NoRData) skip this entirely.
    • Returns f_rdata at position 2 of the rdata tuple.
    • Restores primal via copyto!(yp, old_yp).
    • A TODO marks Option B (per-thread accumulators with Threads.@threads :static + threadid(), O(nthreads) memory) as a future improvement.
  • Adds hand_written_rule_test_cases(rng_ctor, ::Val{:threads}) with Float64 (1-input), Float32 (1-input), Float64 (2-input +), isbits closure, and two heterogeneous-type cases (Float32→Float64, Float32+Float64→Float64).
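The parallel forward/reverse structure described above can be sketched with a toy scalar rule in place of Mooncake's inner rrule!! and plain Float64 tangents in place of CoDuals. All names here are illustrative, not the PR's code:

```julia
# Hand-written scalar rule for tanh, standing in for the inner rrule!! call.
function tanh_rule(x::Float64)
    y = tanh(x)
    return y, dy -> dy * (1 - y^2)          # pullback closes over the primal y
end

function threaded_map_rrule_sketch(xp::Vector{Float64})
    n = length(xp)
    yp = Vector{Float64}(undef, n)
    pullbacks = Vector{Function}(undef, n)  # the PR infers a concrete PBType instead
    Threads.@threads for i in 1:n           # parallel forward pass
        yp[i], pullbacks[i] = tanh_rule(xp[i])
    end
    function pullback!(yd::Vector{Float64}, xd::Vector{Float64})
        Threads.@threads for i in 1:n       # parallel reverse pass
            dy = yd[i]
            yd[i] = 0.0                     # read, then zero (restores forward state)
            xd[i] += pullbacks[i](dy)       # element-local accumulation: race-free
        end
        return xd
    end
    return yp, pullback!
end

x = [0.5, -1.0, 2.0]
y, pb = threaded_map_rrule_sketch(x)
grad = pb(ones(3), zeros(3))                # seed dy = 1 for each element
# grad ≈ 1 .- tanh.(x).^2
```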

src/Mooncake.jl — two lines:

  • include(joinpath("rules", "threads.jl")) after the high_order_derivative_patches.jl include.
  • export threaded_map!.

test/rules/threads.jl — new file: TestUtils.run_rule_test_cases(StableRNG, Val(:threads)).

test/runtests.jl — one elseif branch added: test_group == "rules/threads".

Correctness argument

Race-freedom relies on two constraints enforced at runtime and by the primitive's type signature:

  1. All element types are isbits — each element of every input/output vector is a bits-type
    scalar. Thread i reads/writes only xds[j][i] and yd[i], so there are no shared-memory
    races regardless of whether the element types differ across vectors. The runtime check ensures
    this property holds for the actual vectors passed.
  2. fdata_type(tangent_type(F)) == NoFData — zero_fcodual(fp) produces CoDual{F, NoFData},
    which can be shared across threads with no race on fdata. This covers plain functions and
    closures capturing only isbits scalars. The per-element rdata for f (from closures
    capturing isbits values) is accumulated race-free via per-element storage and a serial fold.
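A sketch of the runtime validation implied by constraint 1 (check_threaded_sketch is a hypothetical stand-in for the PR's _check_threaded_map!; the fdata check on f from constraint 2 is omitted):

```julia
# Validates isbits element types, non-empty inputs, and matching lengths,
# returning the effective n — mirroring the checks described for _check_threaded_map!.
function check_threaded_sketch(output::Vector, inputs::Vector...)
    isempty(inputs) && throw(ArgumentError("at least one input vector is required"))
    isbitstype(eltype(output)) ||
        throw(ArgumentError("output eltype $(eltype(output)) is not isbits"))
    n = length(output)
    for x in inputs
        isbitstype(eltype(x)) ||
            throw(ArgumentError("input eltype $(eltype(x)) is not isbits"))
        length(x) == n || throw(DimensionMismatch("input length $(length(x)) != $n"))
    end
    return n
end

check_threaded_sketch(zeros(3), [1.0, 2.0, 3.0])  # returns 3
```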

MWE

mwe.jl runs on a 100k-element Float64 vector using tanh (higher compute-to-memory ratio than sin) and checks three things:

  1. Primal correctness — threaded_map!(tanh, out, x) matches a serial reference.
  2. Gradient correctness — value_and_gradient!! on sum(threaded_map!(tanh, zeros(n), v)) gives 1 .- tanh.(x).^2 within tolerance.
  3. AD speedup — when JULIA_NUM_THREADS > 1, value_and_gradient!! from a prepared gradient cache using sum(threaded_map!(f, zeros(n), v)) must be ≥ 1.3× faster than the structurally identical sum(map!(f, zeros(n), v)) differentiated by Mooncake's interpreter. Both benchmarks time only value_and_gradient!! (not prepare_gradient_cache). Primal timing is reported informational-only (not asserted) because map! can be SIMD-vectorized, making the primal speedup ratio unreliable.

On main: exits 1 (threaded_map! undefined). On the worktree: exits 0 when all three checks pass.

MWE

Note: MWE scripts are committed to the mwe/ directory on this branch but are removed in the final commit (so they won't be included after squash merge). To check out the MWE locally: git checkout 6a1c58f5b5 -- mwe/ (commit with MWE)

mwe.jl

using Mooncake

# On main, threaded_map! is not defined — expected failure.
if !isdefined(Mooncake, :threaded_map!)
    println("FAIL: Mooncake.threaded_map! is not defined")
    exit(1)
end

n = 100_000  # large enough that threading overhead is amortised

x = randn(Float64, n)
out1 = zeros(Float64, n)
out2 = zeros(Float64, n)

# More compute-heavy function: tanh is transcendental with higher compute-to-memory ratio.
f_bench = tanh

# --- Correctness: primal ---
threaded_map!(f_bench, out1, x)
map!(f_bench, out2, x)
if !isapprox(out1, out2; rtol=1e-12)
    println("FAIL: threaded_map! primal result differs from map!")
    exit(1)
end

# --- Correctness: gradient ---
loss(v) = sum(threaded_map!(f_bench, zeros(n), v))
cache = prepare_gradient_cache(loss, x)
val, (_, grad) = value_and_gradient!!(cache, loss, x)

# tanh'(x) = 1 - tanh(x)^2 = sech(x)^2
expected_val  = sum(tanh.(x))
expected_grad = 1 .- tanh.(x).^2

if !isapprox(val, expected_val; rtol=1e-12)
    println("FAIL: value mismatch: got $val, expected $expected_val")
    exit(1)
end
if !isapprox(grad, expected_grad; rtol=1e-6)
    println("FAIL: gradient mismatch (max err=$(maximum(abs, grad .- expected_grad)))")
    exit(1)
end

# --- Primal timing (informational — not asserted) ---
# Note: map! can be SIMD-vectorised by the compiler, making the primal speedup of
# Threads.@threads over map! unpredictable for lightweight f.
for _ in 1:5
    threaded_map!(f_bench, out1, x)
    map!(f_bench, out2, x)
end
t_threaded = minimum([@elapsed(threaded_map!(f_bench, out1, x)) for _ in 1:20])
t_serial   = minimum([@elapsed(map!(f_bench, out2, x))          for _ in 1:20])
nthreads   = Threads.nthreads()
println("primal: nthreads=$nthreads  serial=$(round(t_serial*1e3; digits=2))ms  " *
        "threaded=$(round(t_threaded*1e3; digits=2))ms  " *
        "speedup=$(round(t_serial/t_threaded; digits=2))x  (informational)")

# --- AD speedup: threaded_map! vs map! differentiated by Mooncake ---
# Both functions have the same structure: sum(map!-variant(f, zeros(n), v)).
# threaded_map! uses a hand-coded parallel rrule!!; map! is traced by Mooncake's interpreter.
# Both benchmarks start from a prepared gradient cache (no compilation overhead).
function loss_threaded(v)
    return sum(threaded_map!(f_bench, zeros(length(v)), v))
end
function loss_serial_ad(v)
    return sum(map!(f_bench, zeros(length(v)), v))
end

cache_t = prepare_gradient_cache(loss_threaded, x)
cache_s = prepare_gradient_cache(loss_serial_ad, x)

# Warm up — run a few iterations to ensure any lazy initialisation is done.
for _ in 1:3
    value_and_gradient!!(cache_t, loss_threaded, x)
    value_and_gradient!!(cache_s, loss_serial_ad, x)
end

# Benchmark from prepared state: time only value_and_gradient!! (not prepare_gradient_cache).
t_ad_t = minimum([@elapsed(value_and_gradient!!(cache_t, loss_threaded, x)) for _ in 1:20])
t_ad_s = minimum([@elapsed(value_and_gradient!!(cache_s, loss_serial_ad, x)) for _ in 1:20])
ad_speedup = t_ad_s / t_ad_t
println("AD:     nthreads=$nthreads  serial=$(round(t_ad_s*1e3; digits=2))ms  " *
        "threaded=$(round(t_ad_t*1e3; digits=2))ms  " *
        "speedup=$(round(ad_speedup; digits=2))x")

if nthreads > 1 && ad_speedup < 1.3
    println("FAIL: expected AD speedup >= 1.3x with $nthreads threads, got $(round(ad_speedup; digits=2))x")
    exit(1)
end

println("PASS")
exit(0)

main output:

# MWE: mwe.jl on main
# started: 2026-03-30 13:30:30
# dir: /home/niko/github/issues/Mooncake.jl
# project: /home/niko/github/issues/Mooncake.jl
---
==> Pkg.instantiate()
==> Running mwe.jl
FAIL: Mooncake.threaded_map! is not defined

# exit_code: 1
# status: FAIL
# finished: 2026-03-30 13:30:33

worktree output:

# MWE: mwe.jl on worktree
# started: 2026-03-30 13:30:30
# dir: /home/niko/github/issues/Mooncake/worktrees/issue-791-threaded-map
# project: /home/niko/github/issues/Mooncake/worktrees/issue-791-threaded-map
---
==> Pkg.instantiate()
Precompiling packages...
  56459.1 ms  ✓ Mooncake
  1 dependency successfully precompiled in 58 seconds. 52 already precompiled.
==> Running mwe.jl
primal: nthreads=12  serial=2.52ms  threaded=0.44ms  speedup=5.79x  (informational)
AD:     nthreads=12  serial=8.1ms  threaded=5.01ms  speedup=1.62x
PASS

# exit_code: 0
# status: PASS
# finished: 2026-03-30 13:31:52

Fixes #791

nsiccha and others added 6 commits March 24, 2026 19:41
Introduces Mooncake.threaded_map!(f, output, input) as a new exported
primitive for element-wise in-place mapping over IEEEFloat vectors using
Threads.@threads.  An rrule is provided via build_primitive_rrule, which
pre-compiles the scalar rule for f at rule-construction time and reuses it
across all elements and calls.  The reverse pass (pullback) runs in
parallel with the same @threads loop, accumulating cotangents into the
input gradient vector.

Constraints: T1,T2 <: IEEEFloat (element-level independence) and f must
carry no mutable differentiable state (rule is shared across threads).

Closes #791.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
make_ad_stmts! always inserts rrule!! for primitive calls in nested IR —
build_primitive_rrule is only used for top-level build_rrule calls.
Replace the ThreadedMapRRule/build_primitive_rrule approach with a direct
rrule!! method. Add an IdDict cache so the inner build_rrule(Tuple{F,T2})
is paid once per (F,T2) type combination.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Accept any number of input vectors: threaded_map!(f, output, x1, x2, ...)
- Update @is_primitive to Vararg{Vector{T}} signature
- Update rrule!! to handle variadic inputs/cotangents; cache key is now (F, N, T);
  inner scalar rule built as build_rrule(Tuple{F, Vararg{T,N}})
- Pullback loops over j in 1:N to accumulate cotangents for all inputs
- Returns ntuple(_ -> NoRData(), Val(N))... for the variadic rdatas
- Remove all @inbounds annotations
- Add 2-input (+) test case in hand_written_rule_test_cases

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When f is a closure capturing IEEEFloat scalars, fdata_type(F)=NoFData so the
forward pass is already race-free. The reverse pass now stores per-element rdata
for f in f_rdatas[i] (each thread writes only its own slot — race-free), then
folds the vector serially via increment_internal!! after the parallel loop.

Returns f_rdata at position 2 of the pullback tuple instead of NoRData(). For
plain functions (rdata_type(tangent_type(F))==NoRData) the fast path skips all
allocation and returns NoRData() as before.

Includes a TODO for Option B (per-thread accumulators with :static scheduling)
and adds an isbits-closure test case to hand_written_rule_test_cases.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Run element 1 serially to determine the concrete pullback type (PBType =
  typeof(pb1)), then allocate Vector{PBType} instead of Vector{Any}. Avoids
  boxing and dynamic dispatch on every pullback call in the reverse pass.
- Add @inbounds to all hot array accesses in rrule!! and the pullback (bounds
  are guaranteed by the length checks on entry).
- Handle n=0 (empty input vectors) with an early-return no-op pullback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Previously all vectors (output + inputs) had to share the same element type T.
Now output::Vector{Tout} and each input::Vector{Ti} can have independent element
types as long as Tout, Ti <: IEEEFloat. Race-freedom is unaffected: element-wise
cotangent writes are still independent by index.

- threaded_map! signature: Vector{<:IEEEFloat} for output and each input vararg
- @is_primitive: Vararg{Vector{<:IEEEFloat}} to match any IEEEFloat combination
- rrule!!: output::CoDual{Vector{Tout},...}, inputs::CoDual{<:Vector{<:IEEEFloat},...}...
- Scalar rule cache key changed from (F, N, T) to (F, input_element_types_tuple);
  rule built as build_rrule(Tuple{F, T1, T2, ..., TN})
- New test cases: Float32→Float64 single input, Float32+Float64→Float64 two inputs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@codecov

codecov Bot commented Mar 24, 2026

Codecov Report

❌ Patch coverage is 0% with 58 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/rules/threads.jl | 0.00% | 58 Missing ⚠️


nsiccha and others added 5 commits March 25, 2026 08:26
- threaded_map!: remove IEEEFloat constraint, add isbitstype() runtime check
- @is_primitive: match Vector without type parameter
- rrule!!: call rrule!! directly (Julia dispatch is the cache); Vector{Any}
  for pullbacks; drop n==0 special case (empty loops handle it naturally)
- Remove all @inbounds annotations

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace Vector{Any} with Vector{PBType} where PBType is inferred via
pullback_type (which uses Core.Compiler.return_type, same technique as
Base.Broadcast.combine_eltypes). Eliminates boxing and dynamic dispatch
on the per-element pullback calls in the reverse pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Val(N) inside a Threads.@threads loop body is computed with N captured
as Int64, so ntuple sees Val{::Int64} and falls back to a generic loop
instead of unrolling. Hoisting val_N = Val(N) to the outer scope (where
N is a compile-time constant) lets ntuple specialize on Val{N}, giving
the expected ~5x primal speedup on n=100k with 12 threads.

Applied to both threaded_map! (primal) and rrule!! forward/pullback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
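The hoisting pattern from this commit can be sketched in isolation (fill_sums! and its names are illustrative, not the PR's code):

```julia
# Illustrative sketch of the Val hoisting described above. Constructing val_N
# outside the loop lets the Threads.@threads closure capture a Val{N}-typed
# object, so ntuple specializes on Val{N} and unrolls; constructing Val(N)
# inside the loop body would capture N as a plain Int and ntuple would fall
# back to a generic loop.
function fill_sums!(out::Vector{Float64}, xs::NTuple{N,Vector{Float64}}) where {N}
    val_N = Val(N)                        # hoisted: N is a compile-time constant here
    Threads.@threads for i in eachindex(out)
        xi = ntuple(j -> xs[j][i], val_N) # specializes and unrolls on Val{N}
        out[i] = sum(xi)
    end
    return out
end

out = zeros(3)
fill_sums!(out, ([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))
# out == [11.0, 22.0, 33.0]
```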
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both loss_serial_ad and loss_threaded now use sum(map!-variant(f, zeros(n), v)),
making the AD speedup comparison structurally identical. Primal correctness
check also uses map! directly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@nsiccha nsiccha changed the title Fix #791: Add support for reverse mode AD of multi-threaded map! operation Fix #791: Add threaded_map! with reverse-mode AD support Mar 25, 2026
@sunxd3
Collaborator

sunxd3 commented Mar 28, 2026

sorry Niko. I don't think I understand Claude's explanation — what is the issue, what is the implementation, and why should I be convinced this is correct?

@nsiccha
Collaborator Author

nsiccha commented Mar 28, 2026

sorry Niko. I don't think I understand Claude's explanation — what is the issue, what is the implementation, and why should I be convinced this is correct?

I must misunderstand the question, because the obvious answers are: a) #791, b) https://github.com/chalk-lab/Mooncake.jl/blob/6a1c58f5b5b33d10dcb843e4a4c428228de401ea/src/rules/threads.jl, c) AFAICT, it's not correct yet and also a bit wonky - but the main logic should be correct and also quite self contained.

But I think you may have been after a different answer?

@yebai yebai closed this Mar 28, 2026
@yebai yebai deleted the issue-791-threaded-map branch March 28, 2026 16:27
@sunxd3
Collaborator

sunxd3 commented Mar 28, 2026

I think I roughly understand the motivation now; I was a little confused by Claude's explanation. The method makes sense to me. But it does feel like there are more things to consider, for instance what kinds of f we can support (would love to discuss further).

Comment thread src/rules/threads.jl Outdated
FRData !== NoRData && (f_rdatas[i] = rdata_tuple[1])
for j in 1:N
xd = xds[j]
xd isa NoFData || (xd[i] += rdata_tuple[j + 1])
Collaborator


is + defined for all bits-type tangents?

@yebai yebai restored the issue-791-threaded-map branch March 28, 2026 23:09
@yebai yebai reopened this Mar 28, 2026
@yebai
Member

yebai commented Mar 28, 2026

I incidentally closed this PR.

nsiccha and others added 4 commits March 30, 2026 12:08
- Extract _check_threaded_map! helper (validation + length calc) called from
  both threaded_map! primal and rrule!!. Previously rrule!! skipped validation.
- Drop old_yd copy: the forward pass sets yd[i]=0, so the pullback can restore
  by zeroing after reading instead of copying and restoring a saved vector.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
For isbits structs, tangent_type(Ti) = Tangent{...} but the inner
pullback returns RData{...} — different types, += fails. For Float64,
both happen to be Float64 so += works, masking the bug.

increment_rdata!!(t, r) = tangent(fdata(t), increment(rdata(t), r))
handles all cases correctly: Float64 (reduces to +), isbits structs
(increments the RData component of the Tangent), NoTangent (no-op).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove issue URL from docstring (line 12)
- Remove unused `where {F}` from threaded_map! (F not referenced in body)
- Add check for f in _check_threaded_map!: fdata_type(tangent_type(typeof(f)))==NoFData
- Widen rrule!! output/input types from CoDual{Vector{Tout},Vector{Tout}} to
  CoDual{<:Vector}: the old annotation required tangent_type(Tout)==Tout, which
  only holds for IEEEFloat eltypes. Tout was also unused in the body.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@nsiccha nsiccha marked this pull request as ready for review March 30, 2026 11:40
@nsiccha
Collaborator Author

nsiccha commented Mar 30, 2026

I think this could now be ready for review, @sunxd3 - do you want to have a look? You already did a bit, didn't you?

personally, I'm a bit confused by function hand_written_rule_test_cases(rng_ctor, ::Val{:threads}) being in src, but claude claims that's the pattern in Mooncake?

Comment thread src/rules/threads.jl
pullbacks = Vector{PBType}(undef, n)
Threads.@threads for i in 1:n
xi = ntuple(j -> CoDual(xps[j][i], NoFData()), valN)
yi_codual, pb = rrule!!(zero_fcodual(fp), xi...)
Collaborator


For non-primitives, a rule needs to be built first.

using Mooncake

function loss_anon(v)
    out = similar(v)
    threaded_map!(x -> 2x, out, v)
    return sum(out)
end

x = randn(4)

# Both fail: threaded_map!'s rrule!! calls rrule!! on f, which has no hand-written rule.
rule = Mooncake.build_rrule(loss_anon, x)
value_and_gradient!!(rule, loss_anon, x)
# => MethodError at src/rules/threads.jl:76

Comment thread src/rules/threads.jl
# but the pullback returns RData{...} — different types; increment_rdata!! handles
# all tangent types uniformly.
Threads.@threads for i in 1:n
dy_i = yd[i]
Collaborator


The pullback reads dy_i = yd[i] and passes it directly to the inner pullback:

dy_i = yd[i]
# ...
rdata_tuple = pullbacks[i](dy_i)

yd[i] is a tangent (the element of the tangent vector). Inner pullbacks expect rdata, not the full tangent. For Float64, tangent and rdata are the same type, so this works by accident. For isbits structs, they diverge:

struct Pair; a::Float64; b::Float64; end

tangent_type(Pair)                  # => Tangent{@NamedTuple{a::Float64, b::Float64}}
rdata_type(tangent_type(Pair))      # => RData{@NamedTuple{a::Float64, b::Float64}}

The canonical pattern in isbits_arrayset_rrule (array_legacy.jl:463) correctly extracts rdata before returning:

dv = rdata(arrayref(false, dA, lin_inds))

The PR should do the same: pullbacks[i](rdata(dy_i)) instead of pullbacks[i](dy_i).

MWE:

using Mooncake
using Mooncake: CoDual, zero_codual, zero_fcodual, rrule!!, NoRData,
    tangent_type, rdata_type, @is_primitive, MinimalCtx, ReverseMode

struct Pair; a::Float64; b::Float64; end

T_rdata = rdata_type(tangent_type(Pair))

# A primitive returning Pair, whose pullback correctly expects RData:
pair_of(x::Float64) = Pair(x, 2x)
@is_primitive MinimalCtx ReverseMode Tuple{typeof(pair_of), Float64}
function Mooncake.rrule!!(::CoDual{typeof(pair_of)}, x::CoDual{Float64})
    px = Mooncake.primal(x)
    function pb(dy)
        dy isa T_rdata || error("pullback got $(typeof(dy)), expected $T_rdata")
        return NoRData(), dy.data.a + 2 * dy.data.b
    end
    return zero_fcodual(Pair(px, 2px)), pb
end

# Forward pass through threaded_map! works:
result, tmb = rrule!!(
    zero_fcodual(threaded_map!),
    zero_fcodual(pair_of),
    zero_codual(Vector{Pair}(undef, 3)),
    zero_codual([1.0, 2.0, 3.0]),
)

# Simulate upstream cotangent accumulation into yd (as sum/getfield pullbacks would):
yd = Mooncake.tangent(result)
for i in 1:3
    yd[i] = Mooncake.increment_rdata!!(yd[i], T_rdata((a=1.0, b=1.0)))
end

# yd[1] is Tangent{...}, but the inner pullback expects RData{...}:
typeof(yd[1])   # => Mooncake.Tangent{@NamedTuple{a::Float64, b::Float64}}
T_rdata         # => Mooncake.RData{@NamedTuple{a::Float64, b::Float64}}

# Pullback fails: threaded_map! passes yd[i] :: Tangent where RData is expected:
tmb(NoRData())
# => Error: pullback got Tangent{...}, expected RData{...}

Comment thread src/rules/threads.jl
return n
end

@is_primitive MinimalCtx ReverseMode Tuple{
Collaborator


maybe a forward mode too?

@yebai
Member

yebai commented Apr 4, 2026

It's okay to close this PR if there's consensus; doing so robustly and usefully is hard.

@yebai
Member

yebai commented Apr 13, 2026

Please do this in a separate repo.

@yebai yebai closed this Apr 13, 2026
@yebai yebai deleted the issue-791-threaded-map branch April 13, 2026 17:32
Development

Successfully merging this pull request may close these issues.

Add support for reverse mode AD of multi-threaded map! operation

3 participants