Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cases/matrix_exp/identity.jsonl

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions docs/generated/complex-support.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ Generated from the checked-in complex-support ledger and the published `cases/`

## Summary

- Total tracked families: 176
- Ready for downstream: 103
- Total tracked families: 177
- Ready for downstream: 104
- Unsupported: 73
- Pending note review: 0
- Pending DB coverage: 0
Expand Down Expand Up @@ -89,6 +89,7 @@ Generated from the checked-in complex-support ledger and the published `cases/`
| lu_factor | identity | not_required | covered | complex128, complex64 | - | yes |
| lu_factor_ex | identity | not_required | covered | complex128, complex64 | - | yes |
| lu_solve | identity | not_required | covered | complex128, complex64 | - | yes |
| matrix_exp | identity | reviewed | covered | complex128, complex64 | - | yes |
| matrix_norm | identity | not_required | covered | complex128, complex64 | - | yes |
| matrix_power | identity | reviewed | covered | complex128, complex64 | - | yes |
| max_binary | identity | not_required | unsupported | - | float-only in pinned PyTorch upstream AD coverage | no |
Expand Down Expand Up @@ -247,6 +248,7 @@ Generated from the checked-in complex-support ledger and the published `cases/`
| lu_factor | identity | not_required | covered | complex128, complex64 | - | yes |
| lu_factor_ex | identity | not_required | covered | complex128, complex64 | - | yes |
| lu_solve | identity | not_required | covered | complex128, complex64 | - | yes |
| matrix_exp | identity | reviewed | covered | complex128, complex64 | - | yes |
| matrix_norm | identity | not_required | covered | complex128, complex64 | - | yes |
| matrix_power | identity | reviewed | covered | complex128, complex64 | - | yes |
| mean | identity | reviewed | covered | complex128, complex64 | - | yes |
Expand Down
3 changes: 2 additions & 1 deletion docs/generated/pytorch-upstream-publish-coverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ and the checked-in `cases/` tree.
- AD-relevant scalar upstream variants: 138
- Mapped publishable success families: 174
- Explicit publishable error families: 2
- Total tracked DB families: 176
- Total tracked DB families: 177

## Publishable Family Coverage

Expand Down Expand Up @@ -93,6 +93,7 @@ publishable upstream coverage that is not yet materialized in this repository.
| lu_factor | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| lu_factor_ex | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| lu_solve | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| matrix_exp | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| matrix_norm | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| matrix_power | identity | success | float64, complex128, float32, complex64 | float64, complex128, float32, complex64 | - |
| max_binary | identity | success | float64, float32 | float64, float32 | - |
Expand Down
13 changes: 13 additions & 0 deletions docs/math/complex-support.json
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,19 @@
},
"unsupported_reason": null
},
{
"op": "matrix_exp",
"family": "identity",
"note": {
"path": "docs/math/matrix_exp.md",
"anchor": "family-identity",
"status": "reviewed"
},
"db": {
"status": "covered"
},
"unsupported_reason": null
},
{
"op": "matrix_power",
"family": "identity",
Expand Down
34 changes: 34 additions & 0 deletions docs/math/eig.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,37 @@ column-wise phase ambiguity of raw eigenvectors.
### `eigvals/identity`

The eigenvalue-only family reuses the diagonal part of the same differential.

## Complex Oracle Strategy

### Phase ambiguity resolution

The `values_vectors_abs` observable publishes eigenvalues together with
`abs(eigenvectors)`. Raw eigenvectors are defined only up to per-column complex
phase $V \mapsto V \operatorname{diag}(e^{i\phi_k})$. Taking the element-wise
absolute value collapses this gauge freedom into a well-defined observable whose
AD is unambiguous.

### Real-to-complex output handling

For real input matrices ($A \in \mathbb{R}^{N \times N}$), PyTorch's
`linalg.eig` returns complex-valued eigenvalues and eigenvectors. The forward
rule operates in the complex domain; the reverse rule projects the cotangent
back to the real domain via $\bar{A} \leftarrow \operatorname{Re}(\bar{A})$
(`handle_r_to_c`). The oracle DB includes float32 and float64 cases that
exercise this path.

### Complex-input coverage

The oracle DB includes complex64 and complex128 input cases. For complex inputs
the full complex formula applies: the normalization correction uses
$\operatorname{Re}(V^\dagger \dot{V}_{\mathrm{raw}})$, the reverse rule uses
$V^{-\dagger} G V^\dagger$, and the gauge invariance check verifies
$\operatorname{Im}(\operatorname{diag}(V^\dagger \bar{V})) = 0$.

### Future considerations

If downstream `tenferro-rs` requires a different observable representation
(sorted eigenvalues, a different gauge-fixing convention, or separate real/imaginary
parts), a new DB family can be added alongside `values_vectors_abs` without
breaking the existing contract.
8 changes: 5 additions & 3 deletions docs/math/matrix_exp.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ PyTorch factors this pattern through the helper
2. A. H. Al-Mohy and N. J. Higham, "Computing the Frechet Derivative of the
Matrix Exponential," 2009.

## DB Status
## DB Families

`matrix_exp` is documented here as a known rule, but it is **not yet materialized**
in the current published `cases/` tree.
<a id="family-identity"></a>
### `matrix_exp/identity`

The DB publishes the matrix exponential output directly.
6 changes: 6 additions & 0 deletions docs/math/registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@
"note_path": "docs/math/norm.md",
"anchor": "family-matrix-norm-identity"
},
{
"op": "matrix_exp",
"family": "identity",
"note_path": "docs/math/matrix_exp.md",
"anchor": "family-identity"
},
{
"op": "matrix_power",
"family": "identity",
Expand Down
28 changes: 26 additions & 2 deletions generators/pytorch_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,32 @@ def _build_scalar_case_specs() -> tuple[CaseFamilySpec, ...]:
return tuple(specs)


def _build_cmi_linalg_case_specs() -> tuple[CaseFamilySpec, ...]:
"""Case specs for linalg ops whose OpInfo lives in common_methods_invocations."""
return (
CaseFamilySpec(
op="matrix_exp",
family="identity",
observable_kind="identity",
expected_behavior="success",
source_file="torch/testing/_internal/common_methods_invocations.py",
source_function="sample_inputs_matrix_exp",
upstream_name="matrix_exp",
upstream_variant_name="",
hvp_enabled=True,
inventory_kind="cmi_linalg",
supported_dtype_names=("float64", "complex128", "float32", "complex64"),
),
)


def _build_case_specs() -> tuple[CaseFamilySpec, ...]:
return _build_success_case_specs() + _build_error_case_specs() + _build_scalar_case_specs()
return (
_build_success_case_specs()
+ _build_error_case_specs()
+ _build_cmi_linalg_case_specs()
+ _build_scalar_case_specs()
)

@lru_cache(maxsize=1)
def _case_specs_cached() -> tuple[CaseFamilySpec, ...]:
Expand Down Expand Up @@ -578,7 +602,7 @@ def _is_skippable_hvp_runtime_error(exc: RuntimeError) -> bool:


def _resolve_runtime_for_spec(spec: CaseFamilySpec):
if spec.inventory_kind == "scalar":
if spec.inventory_kind in ("scalar", "cmi_linalg"):
return import_scalar_generation_runtime()
return import_generation_runtime()

Expand Down
4 changes: 2 additions & 2 deletions scripts/check_upstream_ad_tolerances.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def audit_against_upstream_ad_tolerances(
spec = spec_index[(op, family)]
resolver = (
resolve_upstream_scalar_ad_tolerance
if spec.inventory_kind == "scalar"
if spec.inventory_kind in ("scalar", "cmi_linalg")
else resolve_upstream_ad_tolerance
)
upstream = resolver(
Expand Down Expand Up @@ -116,7 +116,7 @@ def audit_against_upstream_ad_tolerances(
spec = spec_index[(op, family)]
resolver = (
resolve_upstream_scalar_ad_tolerance
if spec.inventory_kind == "scalar"
if spec.inventory_kind in ("scalar", "cmi_linalg")
else resolve_upstream_ad_tolerance
)
upstream = resolver(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_math_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ def test_repo_scalar_ops_note_exposes_representative_op_anchors(self) -> None:

self.assertEqual({"op-abs", "op-add", "op-sum", "op-var"} - anchors, set())

def test_repo_matrix_exp_note_marks_db_status(self) -> None:
def test_repo_matrix_exp_note_has_db_family(self) -> None:
text = (
Path(__file__).resolve().parents[1] / "docs" / "math" / "matrix_exp.md"
).read_text(encoding="utf-8")

self.assertIn("not yet materialized", text)
self.assertIn("family-identity", text)

def test_repo_registry_contains_representative_family_mappings(self) -> None:
root = Path(__file__).resolve().parents[1]
Expand Down
4 changes: 2 additions & 2 deletions validators/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _find_candidate_samples(


def _prepare_samples_for_spec_dtype(torch, spec, *, dtype_name: str) -> list[PreparedSample]:
if getattr(spec, "inventory_kind", "linalg") == "scalar":
if getattr(spec, "inventory_kind", "linalg") in ("scalar", "cmi_linalg"):
_, runtime_source = import_scalar_generation_runtime()
else:
_, runtime_source = import_generation_runtime()
Expand Down Expand Up @@ -249,7 +249,7 @@ def _replay_success_case_for_sample(
input_names = tuple(inputs.keys())
first_order = _first_order_comparison(comparison)
second_order = _second_order_comparison(comparison)
if getattr(spec, "inventory_kind", "linalg") == "scalar":
if getattr(spec, "inventory_kind", "linalg") in ("scalar", "cmi_linalg"):
_, runtime_source = import_scalar_generation_runtime()
else:
_, runtime_source = import_generation_runtime()
Expand Down
Loading