Merged
10 changes: 10 additions & 0 deletions AGENTS.md
@@ -29,6 +29,15 @@ uv sync --locked --all-groups
- Treat the public version string `2.10.0` as the repository contract. Local build suffixes
such as `+cu128` or `+cpu` should not appear in generated provenance.

## JAX Generator Dependencies

- The JAX generator surface is a first-class part of the repository contract.
- Keep `jax` and `jaxlib` pinned to exact matching versions in `pyproject.toml`.
- The current exact JAX pins are `jax==0.9.1` and `jaxlib==0.9.1`.
- Run every JAX command through `uv run`, including generator smoke checks and any
validation commands added for JAX witnesses.
- When updating JAX-related dependencies, refresh `uv.lock` before syncing the environment.
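
The pinning rule above can be sketched as a `pyproject.toml` fragment; the section layout is illustrative of a typical project file, only the pins themselves come from this document:

```toml
[project]
dependencies = [
    "jax==0.9.1",
    "jaxlib==0.9.1",
]
```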

## Operational Notes

- Do not run `uv lock` and `uv sync --locked --all-groups` in parallel. Update the lockfile first, then sync.
@@ -38,6 +47,7 @@ uv sync --locked --all-groups
- `uv run python scripts/verify_cases.py`
- `uv run python scripts/check_replay.py`
- `uv run python scripts/check_regeneration.py`
- Add any JAX-specific validation scripts under the same `uv run` discipline.
- `scripts/check_regeneration.py` compares regenerated JSONL files semantically.
Metadata must match exactly, while numeric tensors may differ only within the
case-level `comparison.rtol` / `comparison.atol`.
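
The comparison rule can be sketched as a small Python helper; this is a minimal illustration of the exact-metadata/tolerant-numerics split, not the actual `check_regeneration.py` implementation, and the flat-record layout is an assumption:

```python
import math

def semantically_equal(ref, regen, rtol, atol):
    """Metadata must match exactly; numeric tensors may drift within tolerance."""
    if ref.keys() != regen.keys():
        return False
    for key, a in ref.items():
        b = regen[key]
        if isinstance(a, list):  # numeric tensor payload (flattened)
            if len(a) != len(b):
                return False
            if any(not math.isclose(x, y, rel_tol=rtol, abs_tol=atol)
                   for x, y in zip(a, b)):
                return False
        elif a != b:             # metadata: exact match required
            return False
    return True
```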
46 changes: 46 additions & 0 deletions README.md
@@ -6,6 +6,37 @@ validation:
- mathematical AD notes
- a machine-readable JSON oracle database

The repository is intentionally dual-backend:

- PyTorch remains the baseline oracle source for the existing case DB
- JAX support is being added in later tasks as a parallel generator and witness
surface

## Planned JAX Surface

The following JAX-facing entrypoint and witness fields are planned for later
tasks and are not implemented yet:

- `uv run python -m generators.jax_v1 --list`
- `jax_ref`
- `linearization`
- `transpose`

These names document the intended future contract so the JAX backend can be
added without changing the vocabulary later.
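
As a sketch of how the planned `jax_ref` / `linearization` / `transpose` witness values could be produced, here is a minimal JAX example for a scalar-style op; the field names come from the list above, while the op choice and variable names are illustrative, not the future generator's API:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(x)

x = jnp.array([0.0, 0.5, 1.0])   # primal input
t = jnp.array([1.0, -1.0, 0.5])  # probe direction / cotangent

jax_ref = f(x)                    # primal reference witness
y, f_lin = jax.linearize(f, x)    # linearization of f at x
linearization = f_lin(t)          # tangent witness
# Transposing the linearized map gives the cotangent witness.
(transpose,) = jax.linear_transpose(f_lin, x)(t)
```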

## JAX Witness Source Policy

JAX witness materialization uses two source modes:

- Prefer `jax_test` for families with a dedicated JAX internal harness.
- Fall back to `torch_aligned` when a published PyTorch case should be reused
as the exact serialized primal-input source.

When a witness comes from a JAX internal harness, the top-level provenance
records the harness `source_file`, `source_function`, `seed`, and a
`comment` entry containing `harness_fullname=...`.
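
A provenance block in the `jax_test` mode might look like the following JSON fragment; the field names are those listed above, while every value shown is illustrative (and the elided `harness_fullname` stays elided):

```json
{
  "provenance": {
    "source": "jax_test",
    "source_file": "tests/example_harness_test.py",
    "source_function": "testExampleGrad",
    "seed": 0,
    "comment": "harness_fullname=..."
  }
}
```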

The oracle database covers both scalar-style `OpInfo` families and linear
algebra operations.

@@ -80,6 +111,11 @@ uv run python scripts/report_upstream_publish_coverage.py
uv run python scripts/report_complex_support.py
```

`uv run python scripts/validate_schema.py` is a repository-integrity and
publish-time check for maintainers and CI. Downstream consumers should treat
`schema/case.schema.json` as the contract and normally do not need to invoke
the repository script directly.

Repository-managed environment files:

- `.python-version`
@@ -90,6 +126,10 @@ The repository requires an exact PyTorch dependency pin: `torch==2.10.0`.
Generated provenance stores the public version string `2.10.0`, not local
build suffixes such as `+cpu` or `+cu128`.

The planned JAX backend will use exact version pins as well; those pins are
tracked in the repository contract now so the later generator work can rely on a
fixed runtime.

## Math Notes

The mathematical AD notes live under `docs/math/`.
@@ -142,6 +182,12 @@ A case is defined by:
- an `observable`
- one or more paired derivative probes

Published JSONL files store materialized numeric tensor payloads directly. For
`success` cases this includes serialized inputs, probe directions, cotangents,
and numeric reference tensors such as `pytorch_ref`, `fd_ref`, and any present
`jax_ref` witness payloads. Downstream readers do not need PyTorch or JAX to
reconstruct those published numbers.
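
For example, a published row can be consumed with nothing but the standard library; the record layout below is illustrative, and only the reference-field names (`pytorch_ref`, `fd_ref`, `jax_ref`) come from this README:

```python
import json

# One hypothetical JSONL row with materialized numeric payloads.
line = ('{"status": "success", "observable": "identity", '
        '"pytorch_ref": [1.0, 2.718281828459045], "fd_ref": [1.0, 2.7182818]}')
case = json.loads(line)

# Numeric payloads are plain JSON arrays, so no torch or jax import is
# needed to reconstruct the published numbers.
pytorch_ref = case["pytorch_ref"]
```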

The database does not require raw decomposition outputs to be the comparison target. For spectral operations, the observable may be a processed output such as `U.abs()`, `S`, `Vh.abs()`, or `U @ Vh`, following the same derivative-relevant observables used by PyTorch AD tests.
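
The processed spectral observables can be illustrated with numpy (the text's torch spellings `U.abs()` and `Vh.abs()` become `np.abs`); this is a sketch of the observable definitions, not repository code:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
U, S, Vh = np.linalg.svd(A)

# Processed observables used instead of the raw SVD factors:
obs_U = np.abs(U)      # U.abs()  -- sign/phase-insensitive
obs_S = S              # S        -- singular values are already well-defined
obs_Vh = np.abs(Vh)    # Vh.abs()
obs_UVh = U @ Vh       # U @ Vh   -- orthogonal polar-like factor
```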

## Oracle Policy
8 changes: 4 additions & 4 deletions cases/abs/identity.jsonl

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions cases/exp/identity.jsonl

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions docs/math/cholesky.md
@@ -1,5 +1,72 @@
# Cholesky AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space Cholesky map before any DB observable projection. For complex
tensors, `Transpose` means the adjoint under the real Frobenius inner product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

The raw operator is the lower-triangular factor

$$
A \mapsto L,
\qquad
A = L L^{\mathsf{H}},
\qquad
A = A^{\mathsf{H}} \succ 0.
$$

## Linearization

With the helper

$$
\varphi(X) = \mathrm{tril}(X) - \tfrac{1}{2}\mathrm{diag}(X),
$$

the raw-output-space linearization is

$$
\dot{L} = L \, \varphi\!\bigl(L^{-1}\dot{A}\,L^{-\mathsf{H}}\bigr).
$$

## JVP

The JVP is exactly the same tangent formula:

$$
\operatorname{jvp}(\operatorname{chol})(A;\dot{A})
= L \, \varphi\!\bigl(L^{-1}\dot{A}\,L^{-\mathsf{H}}\bigr).
$$
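
The tangent formula can be checked numerically against central finite differences of `np.linalg.cholesky`; this numpy sketch covers the real SPD case only and is not part of the repository tooling:

```python
import numpy as np

def phi(X):
    # tril(X) with the diagonal halved
    return np.tril(X) - 0.5 * np.diag(np.diag(X))

def cholesky_jvp(A, dA):
    # Raw-output-space tangent: dL = L * phi(L^{-1} dA L^{-T})
    L = np.linalg.cholesky(A)
    X = np.linalg.solve(L, np.linalg.solve(L, dA).T).T  # L^{-1} dA L^{-T}
    return L, L @ phi(X)
```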

## Transpose

For a raw output cotangent $\bar{L}$, the transpose map is

$$
\bar{A} =
L^{-\mathsf{H}} \,
\varphi^*\!\bigl(\mathrm{tril}(L^{\mathsf{H}}\bar{L})\bigr)
\, L^{-1}.
$$

## VJP (JAX convention)

JAX reads the same raw transpose map directly as the cotangent rule on the
Cholesky factor.

## VJP (PyTorch convention)

PyTorch uses the same triangular-solve sandwich. For `cholesky_ex`, auxiliary
status outputs are treated as metadata, so the VJP applies only to the factor
output.

## Forward Definition

$$
65 changes: 65 additions & 0 deletions docs/math/det.md
@@ -1,5 +1,70 @@
# Determinant AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space determinant maps before any DB observable projection. For
complex tensors, `Transpose` means the adjoint under the real Frobenius inner
product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

This note covers two raw operators:

$$
A \mapsto \det(A),
\qquad
A \mapsto (\operatorname{sign}, \operatorname{logabsdet}).
$$

## Linearization

For `det`,

$$
\dot{d} = \det(A)\operatorname{tr}(A^{-1}\dot{A}).
$$

For `slogdet`, if $w = \operatorname{tr}(A^{-1}\dot{A})$, then

$$
\dot{\operatorname{logabsdet}} = \operatorname{Re}(w),
\qquad
\dot{\operatorname{sign}} = i\,\operatorname{Im}(w)\operatorname{sign}.
$$

## JVP

The JVP is the same linearization evaluated at the tangent matrix $\dot{A}$.
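
The `det` linearization and its real-case adjoint can be verified numerically; this numpy sketch (real matrices only, not repository code) checks the tangent against finite differences and the transpose against the pairing identity:

```python
import numpy as np

def det_jvp(A, dA):
    # d(det) = det(A) * tr(A^{-1} dA)
    return np.linalg.det(A) * np.trace(np.linalg.solve(A, dA))

def det_transpose(A, bar_d):
    # Real-case adjoint: bar_A = bar_d * det(A) * A^{-T}
    return bar_d * np.linalg.det(A) * np.linalg.inv(A).T
```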

## Transpose

For `det`, a raw output cotangent $\bar{d}$ gives

$$
\bar{A} = \overline{\bar{d}\det(A)}\,A^{-\mathsf{H}},
$$

with the real case reducing to $A^{-\mathsf{T}}$.

For `slogdet`, a raw output cotangent on the pair
$(\bar{\operatorname{sign}}, \bar{\operatorname{logabsdet}})$ yields the same
solve-style adjoint summarized later in the note.

## VJP (JAX convention)

JAX reads these transpose maps directly on the scalar or tuple output, with the
same singularity caveats as the raw determinant formulas.

## VJP (PyTorch convention)

PyTorch uses the same raw adjoint structure, together with the real-input
projection and the singular-matrix fallback discussed below.

## 1. Determinant

### Forward Definition
91 changes: 91 additions & 0 deletions docs/math/eig.md
@@ -1,5 +1,96 @@
# General Eigen AD Notes

## Conventions

Unless noted otherwise, `Linearization` and `Transpose` are written for the
raw-output-space eigendecomposition before any DB observable such as
`values_vectors_abs` is applied. For complex tensors, `Transpose` means the
adjoint under the real Frobenius inner product

$$
\langle X, Y \rangle_{\mathbb{R}} = \operatorname{Re}\operatorname{tr}(X^\dagger Y).
$$

## Forward

The raw operator is

$$
A \mapsto (\lambda, V),
\qquad
A V = V \operatorname{diag}(\lambda),
$$

with simple eigenvalues.

## Linearization

Let

$$
\Delta P = V^{-1}\dot{A}\,V.
$$

Then

$$
\dot{\lambda}_i = (\Delta P)_{ii},
\qquad
Q_{ij} = \frac{(\Delta P)_{ij}}{\lambda_j - \lambda_i} \ \ (i \neq j),
\qquad
Q_{ii} = 0,
$$

and the normalized eigenvector tangent is

$$
\dot{V} = VQ - V\,\operatorname{diag}\!\left(\operatorname{Re}(V^\dagger VQ)\right).
$$

## JVP

The JVP is exactly the linearization evaluated at $\dot{A}$, returning the raw
pair $(\dot{\lambda}, \dot{V})$ before any observable removes gauge freedom.

## Transpose

For raw output cotangents $(\bar{\lambda}, \bar{V})$, define

$$
\bar{V}_{\mathrm{adj}} =
\bar{V}
- V \, \operatorname{diag}\!\left(\operatorname{Re}(V^\dagger \bar{V})\right),
$$

then build

$$
G = V^\dagger \bar{V}_{\mathrm{adj}},
\qquad
G_{ij} \leftarrow \frac{G_{ij}}{\overline{\lambda_j - \lambda_i}}
\ \ (i \neq j),
\qquad
G_{ii} = \bar{\lambda}_i,
$$

and finally

$$
\bar{A} = V^{-\dagger} G V^\dagger.
$$
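
The linearization and transpose above are exact adjoints with respect to the stated real inner product, which gives a deterministic self-consistency check. A numpy sketch of both maps (function names illustrative, not repository tooling):

```python
import numpy as np

def eig_linearize(A, dA):
    # Raw tangents (dlam, dV) from the formulas above.
    lam, V = np.linalg.eig(A)
    dP = np.linalg.solve(V, dA @ V)        # V^{-1} dA V
    dlam = np.diag(dP).copy()
    denom = lam[None, :] - lam[:, None]    # lambda_j - lambda_i at (i, j)
    np.fill_diagonal(denom, 1.0)
    Q = dP / denom
    np.fill_diagonal(Q, 0.0)
    Vh = V.conj().T
    dV = V @ Q - V @ np.diag(np.real(np.diag(Vh @ V @ Q)))
    return dlam, dV

def eig_transpose(A, bar_lam, bar_V):
    # Raw adjoint bar_A from the cotangents (bar_lam, bar_V).
    lam, V = np.linalg.eig(A)
    Vh = V.conj().T
    bar_V_adj = bar_V - V @ np.diag(np.real(np.diag(Vh @ bar_V)))
    G = Vh @ bar_V_adj
    denom = np.conj(lam[None, :] - lam[:, None])
    np.fill_diagonal(denom, 1.0)
    G = G / denom
    np.fill_diagonal(G, bar_lam)
    return np.linalg.solve(Vh, G) @ Vh     # V^{-dagger} G V^{dagger}
```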

## VJP (JAX convention)

JAX reads the same raw transpose map for the eigendecomposition outputs. If a
downstream observable removes phase or normalization ambiguity, that observable
is applied after the raw rule.

## VJP (PyTorch convention)

PyTorch uses the same raw adjoint together with the explicit normalization and
gauge checks. For real inputs with complex outputs, the final cotangent is
projected back to the real domain.

## Forward Definition

$$