
Commit aed534a

Merge pull request #19 from tensor4all/feat/unified-jax-pytorch-oracles
[codex] Add JAX oracle witnesses and unified notes
2 parents 671ebb5 + 0eab3db commit aed534a

42 files changed

Lines changed: 3621 additions & 26 deletions


AGENTS.md

Lines changed: 10 additions & 0 deletions
@@ -29,6 +29,15 @@ uv sync --locked --all-groups

- Treat the public version string `2.10.0` as the repository contract. Local build suffixes
  such as `+cu128` or `+cpu` should not appear in generated provenance.

## JAX Generator Dependencies

- The JAX generator surface is a first-class part of the repository contract.
- Keep `jax` and `jaxlib` pinned to exact matching versions in `pyproject.toml`.
- The current exact JAX pins are `jax==0.9.1` and `jaxlib==0.9.1`.
- Run every JAX command through `uv run`, including generator smoke checks and any
  validation commands added for JAX witnesses.
- When updating JAX-related dependencies, refresh `uv.lock` before syncing the environment.

## Operational Notes

- Do not run `uv lock` and `uv sync --locked --all-groups` in parallel. Update the lockfile first, then sync.

@@ -38,6 +47,7 @@ uv sync --locked --all-groups

- `uv run python scripts/verify_cases.py`
- `uv run python scripts/check_replay.py`
- `uv run python scripts/check_regeneration.py`
- Add any JAX-specific validation scripts under the same `uv run` discipline.
- `scripts/check_regeneration.py` compares regenerated JSONL files semantically.
  Metadata must match exactly, while numeric tensors may differ only within the
  case-level `comparison.rtol` / `comparison.atol`.
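The pin contract above can be sketched as a `pyproject.toml` fragment. This is a minimal illustration, not the repository's actual file: the project name and table layout are assumptions, and only the three pin strings come from this document.

```toml
# Illustrative fragment only; the real pyproject.toml layout may differ.
[project]
name = "oracle-db"        # placeholder project name
dependencies = [
    "torch==2.10.0",      # exact PyTorch pin (repository contract)
    "jax==0.9.1",         # exact JAX pin, matched with jaxlib
    "jaxlib==0.9.1",
]
```

After editing pins like these, run `uv lock` first and then `uv sync --locked --all-groups`, per the operational notes.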

README.md

Lines changed: 46 additions & 0 deletions
@@ -6,6 +6,37 @@ validation:

- mathematical AD notes
- a machine-readable JSON oracle database

The repository is intentionally dual-backend:

- PyTorch remains the baseline oracle source for the existing case DB
- JAX support is being added in later tasks as a parallel generator and witness
  surface

## Planned JAX Surface

The following JAX-facing entrypoint and witness fields are planned for later
tasks and are not implemented yet:

- `uv run python -m generators.jax_v1 --list`
- `jax_ref`
- `linearization`
- `transpose`

These names document the intended future contract so the JAX backend can be
added without changing the vocabulary later.

## JAX Witness Source Policy

JAX witness materialization uses two source modes:

- Prefer `jax_test` for families with a dedicated JAX internal harness.
- Fall back to `torch_aligned` when a published PyTorch case should be reused
  as the exact serialized primal-input source.

When a witness comes from a JAX internal harness, the top-level provenance
records the harness `source_file`, `source_function`, `seed`, and a
`comment` entry containing `harness_fullname=...`.

The oracle database covers both scalar-style `OpInfo` families and linear
algebra operations.
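As a hypothetical sketch, a `jax_test`-sourced provenance block following this policy might look like the dictionary below. Only the field names `source_file`, `source_function`, `seed`, and `comment` come from the policy text; the key carrying the source mode and every value are illustrative assumptions, not the real schema.

```python
# Hypothetical provenance block for a jax_test-sourced witness.
# Field names follow the policy text; all values here are illustrative.
provenance = {
    "source_mode": "jax_test",                 # assumed key name for the mode
    "source_file": "tests/linalg_test.py",     # illustrative harness path
    "source_function": "testCholeskyGrad",     # illustrative harness function
    "seed": 0,
    "comment": "harness_fullname=LinalgTest.testCholeskyGrad",
}

# The comment entry is expected to carry the harness_fullname marker.
assert provenance["comment"].startswith("harness_fullname=")
```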

@@ -80,6 +111,11 @@ uv run python scripts/report_upstream_publish_coverage.py
uv run python scripts/report_complex_support.py
```

`uv run python scripts/validate_schema.py` is a repository-integrity and
publish-time check for maintainers and CI. Downstream consumers should treat
`schema/case.schema.json` as the contract and normally do not need to invoke
the repository script directly.

Repository-managed environment files:

- `.python-version`
@@ -90,6 +126,10 @@ The repository requires an exact PyTorch dependency pin: `torch==2.10.0`.

Generated provenance stores the public version string `2.10.0`, not local
build suffixes such as `+cpu` or `+cu128`.

The planned JAX backend will use exact version pins as well; those pins are
tracked in the repository contract now so the later generator work can rely on
a fixed runtime.

## Math Notes

The mathematical AD notes live under `docs/math/`.
@@ -142,6 +182,12 @@ A case is defined by:

- an `observable`
- one or more paired derivative probes

Published JSONL files store materialized numeric tensor payloads directly. For
`success` cases this includes serialized inputs, probe directions, cotangents,
and numeric reference tensors such as `pytorch_ref`, `fd_ref`, and any present
`jax_ref` witness payloads. Downstream readers do not need PyTorch or JAX to
reconstruct those published numbers.

The database does not require raw decomposition outputs to be the comparison
target. For spectral operations, the observable may be a processed output such
as `U.abs()`, `S`, `Vh.abs()`, or `U @ Vh`, following the same
derivative-relevant observables used by PyTorch AD tests.
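Because published lines are self-contained, they can be consumed with only the Python standard library. The sketch below parses a made-up miniature case line; the real payload layout is defined by `schema/case.schema.json`, so the fields here are illustrative beyond the `pytorch_ref` / `fd_ref` / `jax_ref` names mentioned above.

```python
import json

# A made-up miniature case line; real lines follow schema/case.schema.json.
line = '{"status": "success", "pytorch_ref": [1.0, 2.0], "fd_ref": [1.0, 2.0000001]}'

case = json.loads(line)
# Collect whichever numeric reference witnesses are present on this line.
refs = {key: case[key] for key in ("pytorch_ref", "fd_ref", "jax_ref") if key in case}

assert case["status"] == "success"
assert set(refs) == {"pytorch_ref", "fd_ref"}
```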

## Oracle Policy

cases/abs/identity.jsonl

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

cases/exp/identity.jsonl

Lines changed: 12 additions & 12 deletions
Large diffs are not rendered by default.

docs/math/cholesky.md

Lines changed: 67 additions & 0 deletions
@@ -1,5 +1,72 @@

# Cholesky AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space Cholesky map before any DB observable projection. For complex
tensors, `Transpose` means the adjoint under the real Frobenius inner product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

The raw operator is the lower-triangular factor

$$
A \mapsto L,
\qquad
A = L L^{\mathsf{H}},
\qquad
A = A^{\mathsf{H}} \succ 0.
$$

## Linearization

With the helper

$$
\varphi(X) = \mathrm{tril}(X) - \tfrac{1}{2}\mathrm{diag}(X),
$$

the raw-output-space linearization is

$$
\dot{L} = L \, \varphi\!\bigl(L^{-1}\dot{A}\,L^{-\mathsf{H}}\bigr).
$$

## JVP

The JVP is exactly the same tangent formula:

$$
\operatorname{jvp}(\operatorname{chol})(A;\dot{A})
= L \, \varphi\!\bigl(L^{-1}\dot{A}\,L^{-\mathsf{H}}\bigr).
$$

## Transpose

For a raw output cotangent $\bar{L}$, the transpose map is

$$
\bar{A} =
L^{-\mathsf{H}} \,
\varphi^*\!\bigl(\mathrm{tril}(L^{\mathsf{H}}\bar{L})\bigr)
\, L^{-1}.
$$

## VJP (JAX convention)

JAX reads the same raw transpose map directly as the cotangent rule on the
Cholesky factor.

## VJP (PyTorch convention)

PyTorch uses the same triangular-solve sandwich. For `cholesky_ex`, auxiliary
status outputs are treated as metadata, so the VJP applies only to the factor
output.

## Forward Definition

$$
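The linearization above can be sanity-checked numerically. The sketch below uses NumPy and central finite differences on a small SPD matrix with a symmetric tangent; the size, seed, and tolerance are arbitrary choices, not repository settings.

```python
import numpy as np

def phi(X):
    # tril(X) - (1/2) diag(X), the helper from the linearization above
    return np.tril(X) - 0.5 * np.diag(np.diag(X))

rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T + 4.0 * np.eye(4)          # symmetric positive definite
A_dot = rng.standard_normal((4, 4))
A_dot = A_dot + A_dot.T                # symmetric tangent direction

L = np.linalg.cholesky(A)
L_inv = np.linalg.inv(L)
L_dot = L @ phi(L_inv @ A_dot @ L_inv.T)   # closed-form tangent from the note

# Central finite differences on the factorization itself.
eps = 1e-6
L_dot_fd = (np.linalg.cholesky(A + eps * A_dot)
            - np.linalg.cholesky(A - eps * A_dot)) / (2.0 * eps)
assert np.allclose(L_dot, L_dot_fd, atol=1e-6)
```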

docs/math/det.md

Lines changed: 65 additions & 0 deletions
@@ -1,5 +1,70 @@

# Determinant AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space determinant maps before any DB observable projection. For
complex tensors, `Transpose` means the adjoint under the real Frobenius inner
product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

This note covers two raw operators:

$$
A \mapsto \det(A),
\qquad
A \mapsto (\operatorname{sign}, \operatorname{logabsdet}).
$$

## Linearization

For `det`,

$$
\dot{d} = \det(A)\operatorname{tr}(A^{-1}\dot{A}).
$$

For `slogdet`, if $w = \operatorname{tr}(A^{-1}\dot{A})$, then

$$
\dot{\operatorname{logabsdet}} = \operatorname{Re}(w),
\qquad
\dot{\operatorname{sign}} = i\,\operatorname{Im}(w)\operatorname{sign}.
$$

## JVP

The JVP is the same linearization evaluated at the tangent matrix $\dot{A}$.

## Transpose

For `det`, a raw output cotangent $\bar{d}$ gives

$$
\bar{A} = \overline{\bar{d}\det(A)}\,A^{-\mathsf{H}},
$$

with the real case reducing to $\bar{d}\det(A)\,A^{-\mathsf{T}}$.

For `slogdet`, a raw output cotangent on the pair
$(\bar{\operatorname{sign}}, \bar{\operatorname{logabsdet}})$ yields the same
solve-style adjoint summarized later in the note.

## VJP (JAX convention)

JAX reads these transpose maps directly on the scalar or tuple output, with the
same singularity caveats as the raw determinant formulas.

## VJP (PyTorch convention)

PyTorch uses the same raw adjoint structure, together with the real-input
projection and the singular-matrix fallback discussed below.

## 1. Determinant

### Forward Definition
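The `det` and `slogdet` tangent formulas can be checked numerically against central finite differences. The NumPy sketch below assumes a random nonsingular complex matrix; size, seed, and tolerances are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
A_dot = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

# det tangent: d_dot = det(A) * tr(A^{-1} A_dot), with w = tr(A^{-1} A_dot)
w = np.trace(np.linalg.solve(A, A_dot))
d_dot = np.linalg.det(A) * w

eps = 1e-7
d_dot_fd = (np.linalg.det(A + eps * A_dot)
            - np.linalg.det(A - eps * A_dot)) / (2 * eps)
assert np.allclose(d_dot, d_dot_fd, atol=1e-5)

# slogdet tangents: logabsdet_dot = Re(w), sign_dot = i * Im(w) * sign
sign, logabsdet = np.linalg.slogdet(A)
sign_p, logabs_p = np.linalg.slogdet(A + eps * A_dot)
sign_m, logabs_m = np.linalg.slogdet(A - eps * A_dot)
assert np.allclose(np.real(w), (logabs_p - logabs_m) / (2 * eps), atol=1e-5)
assert np.allclose(1j * np.imag(w) * sign, (sign_p - sign_m) / (2 * eps), atol=1e-5)
```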

docs/math/eig.md

Lines changed: 91 additions & 0 deletions
@@ -1,5 +1,96 @@

# General Eigen AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space eigendecomposition before any DB observable such as
`values_vectors_abs` is applied. For complex tensors, `Transpose` means the
adjoint under the real Frobenius inner product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

The raw operator is

$$
A \mapsto (\lambda, V),
\qquad
A V = V \operatorname{diag}(\lambda),
$$

with simple eigenvalues.

## Linearization

Let

$$
\Delta P = V^{-1}\dot{A}\,V.
$$

Then

$$
\dot{\lambda}_i = (\Delta P)_{ii},
\qquad
Q_{ij} = \frac{(\Delta P)_{ij}}{\lambda_j - \lambda_i} \ \ (i \neq j),
\qquad
Q_{ii} = 0,
$$

and the normalized eigenvector tangent is

$$
\dot{V} = VQ - V\,\operatorname{diag}\!\left(\operatorname{Re}(V^\dagger VQ)\right).
$$

## JVP

The JVP is exactly the linearization evaluated at $\dot{A}$, returning the raw
pair $(\dot{\lambda}, \dot{V})$ before any observable removes gauge freedom.

## Transpose

For raw output cotangents $(\bar{\lambda}, \bar{V})$, define

$$
\bar{V}_{\mathrm{adj}} =
\bar{V}
- V \, \operatorname{diag}\!\left(\operatorname{Re}(V^\dagger \bar{V})\right),
$$

then build

$$
G = V^\dagger \bar{V}_{\mathrm{adj}},
\qquad
G_{ij} \leftarrow \frac{G_{ij}}{\overline{\lambda_j - \lambda_i}}
\ \ (i \neq j),
\qquad
G_{ii} = \bar{\lambda}_i,
$$

and finally

$$
\bar{A} = V^{-\dagger} G V^\dagger.
$$

## VJP (JAX convention)

JAX reads the same raw transpose map for the eigendecomposition outputs. If a
downstream observable removes phase or normalization ambiguity, that observable
is applied after the raw rule.

## VJP (PyTorch convention)

PyTorch uses the same raw adjoint together with the explicit normalization and
gauge checks. For real inputs with complex outputs, the final cotangent is
projected back to the real domain.

## Forward Definition

$$
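The tangent formulas above can be verified against the differentiated eigen-equation $\dot{A}V + A\dot{V} = \dot{V}\operatorname{diag}(\lambda) + V\operatorname{diag}(\dot{\lambda})$, which avoids finite-difference eigenvalue-ordering issues. A NumPy sketch, assuming a random matrix with distinct eigenvalues (sizes, seed, and tolerances are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
A_dot = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

lam, V = np.linalg.eig(A)                  # columns of V have unit norm
dP = np.linalg.solve(V, A_dot @ V)         # Delta P = V^{-1} A_dot V

lam_dot = np.diag(dP).copy()
denom = lam[None, :] - lam[:, None]        # entry (i, j) = lambda_j - lambda_i
np.fill_diagonal(denom, 1.0)               # avoid dividing by zero on the diagonal
Q = dP / denom
np.fill_diagonal(Q, 0.0)

VQ = V @ Q
V_dot = VQ - V @ np.diag(np.real(np.diag(V.conj().T @ VQ)))

# Differentiating A V = V diag(lambda) gives a residual that must vanish.
resid = A_dot @ V + A @ V_dot - V_dot @ np.diag(lam) - V @ np.diag(lam_dot)
assert np.allclose(resid, 0.0, atol=1e-10)

# Gauge condition: Re diag(V^H V_dot) = 0, i.e. the tangent is norm-preserving.
assert np.allclose(np.real(np.diag(V.conj().T @ V_dot)), 0.0, atol=1e-12)
```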
