feat: add missing linalg ops with AD (oracle coverage) (#649)#652
Merged
feat: add missing linalg ops with AD (oracle coverage) (#649)#652
Conversation
…valsh, eigvals, pinv, norm) - Add Lu primitive (backend + StdTensorOp + compiler/exec + AD JVP) - Add Eig primitive (backend + StdTensorOp + compiler/exec + AD JVP eigenvalues only) - Refactor Solve from primitive to LU + triangular_solve composition - Add compositions: slogdet, det, inv, eigvalsh, eigvals, pinv, norm - Add oracle replay for lu, slogdet, det, inv, eigvalsh, solve - Add ReduceProd/ReduceMax/ReduceMin AD rules - eig/eigvals/pinv/norm oracle replay partially enabled (needs fixes) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…igvals - Add complex128 tensor decoding and comparison in oracle replay - Fix pinv default rtol to match PyTorch, add pinv_with_rtol API - Fix norm for ord=0, matrix induced norms (1, -1, inf) - Skip eig/eigvals oracle (needs complex-safe JVP + transpose rule) - Skip 2 norm ord=-inf matrix cases (subgradient mismatch) Oracle results: pinv 24/24 passed, norm 100/102 passed (2 skipped) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eig/norm AD
- Add StdTensorOp::Convert { from, to } with linearize + transpose rules
(real↔complex are mutual adjoints)
- StableHLO: Convert maps to standard stablehlo.convert (not CustomCall)
- Fix linearize_eig to use Convert instead of Mul-by-complex-one hack
- Enable eig/eigvals oracle replay (eigvals 8/8 passed, eig skips vector derivatives per JAX)
- Fix norm ord=-inf matrix VJP (all 102 norm cases now pass)
- Add CPU convert kernel, backend trait method, compiler/exec wiring
- Add tests: Convert eval/JVP/VJP, eigvals gradient, norm regression
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…7% coverage) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
stablehlo.convert) for dtype conversion with adjoint AD rules (real↔complex are mutual adjoints)Oracle Results
Design Decisions
Convert{f64→c64}transpose =Convert{c64→f64}(real-part extraction). Standard StableHLO op, not CustomCallNotImplementedError)Closes #649
Test plan
cargo test --workspace --release— all passfailed: 0cargo llvm-cov)cargo doc --workspace --no-deps)🤖 Generated with Claude Code