Standalone implementation of a differentiable symmetric/Hermitian eigenvalue decomposition (`eigh`) with CPU (LAPACK) and GPU (cuSOLVER) backends, extracted from pyscfad.
- Generalized Problems: `A @ V = B @ V @ diag(W)`, etc.
- JAX Integrated: full support for `jit`, `vmap`, `grad`, and `jvp`.
- High Performance: optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
- Precision: `float32`/`float64` and `complex64`/`complex128`.
- Degeneracy Handling: configurable `deg_thresh` for stable gradients.
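The JAX transformations compose as with any other JAX function. The following is a minimal sketch, assuming the package's `eigh` accepts a single symmetric matrix and returns `(w, v)` with eigenvalues in ascending order, as a SciPy-compatible interface suggests. If the package cannot be imported, the sketch falls back to `jax.numpy.linalg.eigh` (standard problem only) so it still runs; `lowest_eigenvalue` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

try:
    from eigh import eigh  # this package, if installed
except ImportError:
    # Assumption: fall back to JAX's built-in solver (standard problem only)
    # so the sketch runs without the compiled extension.
    from jax.numpy.linalg import eigh


def lowest_eigenvalue(a):
    """Hypothetical helper: smallest eigenvalue (assumes ascending order)."""
    return eigh(a)[0][0]


# Batch of [[2, t], [t, 2]] matrices; their eigenvalues are 2 - t and 2 + t for t > 0.
ts = jnp.array([0.5, 1.0, 1.5])
swap = jnp.array([[0.0, 1.0], [1.0, 0.0]])
batch = 2.0 * jnp.eye(2) + ts[:, None, None] * swap

# jit(vmap(...)) evaluates the whole batch in one compiled call.
lows = jax.jit(jax.vmap(lowest_eigenvalue))(batch)
# Composing grad as well gives the per-matrix gradient of the lowest eigenvalue.
grads = jax.jit(jax.vmap(jax.grad(lowest_eigenvalue)))(batch)
print(grads.shape)  # (3, 2, 2)
```

The gradient of a nondegenerate eigenvalue is `v vᵀ` for its normalized eigenvector `v`, so each gradient's diagonal sums to 1; how off-diagonal entries are distributed between the triangles may depend on the solver's `lower` convention.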
```bash
# Install from source
pip install .

# For GPU support in this environment
pip install ".[cuda-local]"
```

```python
import jax
import jax.numpy as jnp
from eigh import eigh

jax.config.update("jax_enable_x64", True)

A = jnp.array([[2., 1.], [1., 2.]])
w, v = eigh(A)  # standard problem
grad = jax.grad(lambda A: eigh(A)[0].sum())(A)  # differentiable
```

- `eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`: SciPy-compatible interface. `type` selects the generalized problem: 1: `A @ v = B @ v @ diag(λ)`, 2: `A @ B @ v = v @ diag(λ)`, 3: `B @ A @ v = v @ diag(λ)`.
- `eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`: lower-level generalized solver.
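To make the three `type` conventions concrete, the sketch below solves each one with the classic Cholesky reduction `B = L Lᵀ` to a standard problem, using only `jax.numpy` and `jax.scipy.linalg.solve_triangular`. `generalized_eigh_reference` is a hypothetical helper for cross-checking results, not this package's API; the package may reduce the problem differently internally. It assumes `B` is symmetric positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def generalized_eigh_reference(a, b, itype=1):
    """Hypothetical reference solver via B = L @ L.T (B symmetric positive definite).

    itype 1: A v = B v diag(w)
    itype 2: A B v = v diag(w)
    itype 3: B A v = v diag(w)
    """
    l = jnp.linalg.cholesky(b)  # lower-triangular factor
    if itype == 1:
        # C = L^{-1} A L^{-T}
        tmp = solve_triangular(l, a, lower=True)
        c = solve_triangular(l, tmp.T, lower=True).T
    else:
        # itype 2 and 3 share C = L^T A L
        c = l.T @ a @ l
    w, y = jnp.linalg.eigh(c)
    if itype == 3:
        v = l @ y  # back-transform v = L y
    else:
        v = solve_triangular(l.T, y, lower=False)  # back-transform v = L^{-T} y
    return w, v


a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
b = jnp.array([[2.0, 0.5], [0.5, 1.0]])
residuals = {
    1: lambda w, v: a @ v - b @ v @ jnp.diag(w),
    2: lambda w, v: a @ b @ v - v @ jnp.diag(w),
    3: lambda w, v: b @ a @ v - v @ jnp.diag(w),
}
for itype, residual in residuals.items():
    w, v = generalized_eigh_reference(a, b, itype)
    print(itype, float(jnp.abs(residual(w, v)).max()))  # should be near machine precision
```

If the package is installed, eigenvalues from `eigh(a, b, type=t)` would be expected to agree with `w` here; eigenvectors can differ by sign, so comparing residuals is more robust than comparing vectors.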
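Gradients of individual eigenvalues and eigenvectors are ill-defined for degenerate (repeated) eigenvalues, because the eigenvector derivative contains terms of the form 1/(λ_j - λ_i). Functions that are symmetric in the eigenvalues (such as their sum or variance, or the trace) still have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero eigenvalue gaps to keep the gradients finite.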
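The masking can be illustrated with the first-order perturbation formulas for a real symmetric matrix: eigenvalue tangents are `diag(Vᵀ dA V)`, and eigenvector tangents scale the off-diagonal entries of `Vᵀ dA V` by `1/(w_j - w_i)`. The sketch below assumes that a threshold like `deg_thresh` simply zeroes pairs whose gap is below it; it is not the package's kernel. It uses `jax.numpy.linalg.eigh` for the decomposition, and `masked_inverse_gaps` and `eigh_jvp_sketch` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def masked_inverse_gaps(w, deg_thresh=1e-9):
    """Hypothetical helper: F[i, j] = 1 / (w[j] - w[i]) if the gap exceeds deg_thresh, else 0."""
    gaps = w[None, :] - w[:, None]
    keep = jnp.abs(gaps) > deg_thresh  # also False on the diagonal
    # Divide by a harmless 1.0 where masked so no inf/nan is ever formed.
    safe_gaps = jnp.where(keep, gaps, 1.0)
    return jnp.where(keep, 1.0 / safe_gaps, 0.0)


def eigh_jvp_sketch(a, da, deg_thresh=1e-9):
    """Hypothetical helper: eigenpairs of symmetric `a` and their tangents along symmetric `da`."""
    w, v = jnp.linalg.eigh(a)
    p = v.T @ da @ v  # perturbation expressed in the eigenbasis
    dw = jnp.diag(p)  # eigenvalue tangents
    dv = v @ (masked_inverse_gaps(w, deg_thresh) * p)  # eigenvector tangents
    return w, v, dw, dv


# Spectrum {1, 1, 2}: the gap between the first two eigenvalues is zero.
a = jnp.diag(jnp.array([1.0, 1.0, 2.0]))
da = jnp.array([[0.1, 0.2, 0.0],
                [0.2, -0.3, 0.1],
                [0.0, 0.1, 0.4]])
w, v, dw, dv = eigh_jvp_sketch(a, da)
print(bool(jnp.all(jnp.isfinite(dv))))  # True: masked instead of dividing by zero
print(bool(jnp.allclose(dw.sum(), jnp.trace(da))))  # True: the eigenvalue sum is stable
```

Masking discards the coupling between eigenvectors inside a degenerate subspace, so their individual tangents are a convention, while symmetric quantities such as the eigenvalue sum (which equals `trace(dA)` here) are unaffected.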
- Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
- Tests:
  ```bash
  pytest tests/test_eigh.py      # Core functionality
  pytest tests/test_eigh_gen.py  # Generalized itypes
  pytest tests/test_eigh_jit.py  # JIT & vmap
  ```
- GPU Setup:
  ```bash
  source setup_gpu_env_clean.sh
  ./run_gpu.sh python example_simple.py
  ```
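After sourcing the setup script, a quick sanity check is to ask JAX which backend it selected. This only confirms that JAX itself sees a GPU; it does not exercise this package's cuSOLVER kernels. `jax.default_backend()` and `jax.devices()` are standard JAX calls.

```python
import jax

# Platform of the default device: "gpu" when JAX sees a CUDA device, otherwise "cpu".
backend = jax.default_backend()
print("backend:", backend)
print("devices:", jax.devices())
```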
Apache License 2.0. If used in research, please cite:
```bibtex
@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}
```