Description
Enzyme.gradient through a function with a custom augmented_primal/reverse rule throws MixedReturnException at runtime when the return type involves a @generated PolyAlgorithm with 3 or more mixed sub-algorithm types. It works with 1 or 2 mixed types, and with any number of identical types.
The error is thrown at runtime from runtime_generic_augfwd, not at compile time: autodiff_thunk compiles successfully and has_rrule_from_sig returns true. Julia's Base.return_types infers a single concrete type in all cases.
MWE
using NonlinearSolve, Enzyme

f(u, p) = u .^ 2 .- p

# ✓ Works: 2 mixed algorithm types
poly2 = NonlinearSolvePolyAlgorithm([NewtonRaphson(), TrustRegion()])
Enzyme.gradient(Enzyme.Reverse,
    p -> sum(solve(NonlinearProblem(f, [1.0], p), poly2).u), [2.0])
# => ([0.353...],)

# ✗ Fails: 3 mixed algorithm types
poly3 = NonlinearSolvePolyAlgorithm([NewtonRaphson(), TrustRegion(), LevenbergMarquardt()])
Enzyme.gradient(Enzyme.Reverse,
    p -> sum(solve(NonlinearProblem(f, [1.0], p), poly3).u), [2.0])
# => MixedReturnException

# ✓ Works: 6 identical types (larger count, no type diversity)
poly6 = NonlinearSolvePolyAlgorithm([NewtonRaphson() for _ in 1:6])
Enzyme.gradient(Enzyme.Reverse,
    p -> sum(solve(NonlinearProblem(f, [1.0], p), poly6).u), [2.0])
# => ([0.353...],)
Key observations
- autodiff_thunk compiles, so the error occurs at runtime, not at compile time
- has_rrule_from_sig returns true for all cases
- Base.return_types returns a single concrete type for all cases
- Without any custom rule, Enzyme differentiates through all cases natively (including 6 mixed types)
- The custom rule in NonlinearSolveBaseEnzymeExt (on solve_up) handles Union{Type{Duplicated{RT}}, Type{MixedDuplicated{RT}}}
- I could not reproduce this with pure Julia types; the trigger seems specific to the complexity of the NonlinearSolve dispatch chain inside the rule body
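
For anyone who wants to repeat the inference check from the observations above, here is a minimal sketch of how it could be done. The helper is_concretely_inferred is a hypothetical name introduced only for illustration (it is not part of Enzyme or NonlinearSolve), and the two loss functions are stand-ins, not the functions from the MWE.

```julia
# Hypothetical helper (not an Enzyme or NonlinearSolve API): returns true when
# Julia's inference reports exactly one return type and that type is concrete.
function is_concretely_inferred(f, argtypes::Tuple)
    rts = Base.return_types(f, argtypes)
    return length(rts) == 1 && isconcretetype(only(rts))
end

# Stand-in functions for illustration only.
stable_loss(p) = sum(abs2, p)             # always returns Float64 for Vector{Float64}
unstable_loss(p) = p[1] > 0 ? 1 : 1.0     # returns Int or Float64 depending on the value

stable_ok = is_concretely_inferred(stable_loss, (Vector{Float64},))
unstable_ok = is_concretely_inferred(unstable_loss, (Vector{Float64},))
```

In the setting of this report, the same helper would presumably be called on p -> sum(solve(NonlinearProblem(f, [1.0], p), poly3).u) with (Vector{Float64},). A true result there only says that inference is not the distinguishing factor between the working and failing cases; it does not explain the exception.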
Versions
- Julia 1.10.11
- Enzyme 0.13.134
- NonlinearSolve 4.16.0
- NonlinearSolveBase 2.16.0
Related
- NonlinearSolve.jl#878
- SciMLSensitivity.jl#1358
🤖 Generated with Claude Code
Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com