From 9273cfb0f86630c5197fda390c494a09fd730699 Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Tue, 10 Feb 2026 15:27:21 -0600 Subject: [PATCH 1/2] update for slotdet when n=1 or 2 --- jax/_src/numpy/linalg.py | 68 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index cd000862bacc..d4093b342382 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -465,6 +465,71 @@ def matrix_rank( return reductions.sum(S > rtol, axis=-1) +def _slogdet_1x1(a: Array) -> tuple[Array, Array]: + """Analytic slogdet for 1x1 matrices. No LU/solve; avoids TRSV on ROCm.""" + dtype = lax.dtype(a) + det = a[..., 0, 0] + abs_det = ufuncs.abs(det) + is_zero = abs_det == jnp.array(0, dtype=dtype) + if jnp.iscomplexobj(a): + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), det / ufuncs.astype(abs_det, det.dtype)) + else: + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), ufuncs.sign(det)) + logabsdet = jnp.where( + is_zero, jnp.array(-np.inf, dtype=dtype), ufuncs.log(abs_det).astype(dtype)) + return sign, ufuncs.real(logabsdet) + + +def _slogdet_2x2(a: Array) -> tuple[Array, Array]: + """Analytic slogdet for 2x2 matrices. No LU/solve; avoids TRSV on ROCm.""" + dtype = lax.dtype(a) + a00, a01 = a[..., 0, 0], a[..., 0, 1] + a10, a11 = a[..., 1, 0], a[..., 1, 1] + det = (a00 * a11) - (a01 * a10) + abs_det = ufuncs.abs(det) + is_zero = abs_det == jnp.array(0, dtype=dtype) + if jnp.iscomplexobj(a): + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), det / ufuncs.astype(abs_det, det.dtype)) + else: + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), ufuncs.sign(det)) + logabsdet = jnp.where( + is_zero, jnp.array(-np.inf, dtype=dtype), ufuncs.log(abs_det).astype(dtype)) + return sign, ufuncs.real(logabsdet) + + +@custom_jvp +def _slogdet_small(a: Array) -> tuple[Array, Array]: + """slogdet for n in {1, 2} using analytic formulas only (no solve).""" + n = a.shape[-1] + if n == 1: + return _slogdet_1x1(a) + elif n == 2: + return _slogdet_2x2(a) + else: + raise ValueError(f"_slogdet_small only supports n in {{1, 2}}, got n={n}") + + +def _slogdet_small_jvp(primals, tangents): + x, = primals + g, = tangents + n = x.shape[-1] + sign, ans = _slogdet_small(x) + if n == 1: + ans_dot = ufuncs.real(g[..., 0, 0] / x[..., 0, 0]) + else: # n == 2 + x00, x01 = x[..., 0, 0], x[..., 0, 1] + x10, x11 = x[..., 1, 0], x[..., 1, 1] + det = (x00 * x11) - (x01 * x10) + ans_dot = ufuncs.real( + (g[..., 0, 0] * x11 - g[..., 0, 1] * x10 - g[..., 1, 0] * x01 + g[..., 1, 1] * x00) / det + ) + sign_dot = array_creation.zeros_like(sign) + return (sign, ans), (sign_dot, ans_dot) + + +_slogdet_small.defjvp(_slogdet_small_jvp) + + @custom_jvp def _slogdet_lu(a: Array) -> tuple[Array, Array]: dtype = lax.dtype(a) @@ -548,6 +613,9 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": + n = a_shape[-1] + if n == 1 or n == 2: + return SlogdetResult(*_slogdet_small(a)) return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": return SlogdetResult(*_slogdet_qr(a)) From 707a87e87df211e6324708d877d24c03e5567c17 Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Tue, 10 Feb 2026 16:31:12 -0600 Subject: [PATCH 2/2] update with a simplication for slogdet --- jax/_src/numpy/linalg.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index d4093b342382..976c509a7f54 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -465,10 +465,9 @@ def matrix_rank( return reductions.sum(S > rtol, axis=-1) -def _slogdet_1x1(a: Array) -> tuple[Array, Array]: - """Analytic slogdet for 1x1 matrices. No LU/solve; avoids TRSV on ROCm.""" +def _slogdet_from_det(det: Array, a: Array) -> tuple[Array, Array]: + """Helper to compute slogdet from a pre-computed determinant.""" dtype = lax.dtype(a) - det = a[..., 0, 0] abs_det = ufuncs.abs(det) is_zero = abs_det == jnp.array(0, dtype=dtype) if jnp.iscomplexobj(a): @@ -480,21 +479,18 @@ def _slogdet_1x1(a: Array) -> tuple[Array, Array]: return sign, ufuncs.real(logabsdet) +def _slogdet_1x1(a: Array) -> tuple[Array, Array]: + """Analytic slogdet for 1x1 matrices. No LU/solve; avoids TRSV on ROCm.""" + det = a[..., 0, 0] + return _slogdet_from_det(det, a) + + def _slogdet_2x2(a: Array) -> tuple[Array, Array]: """Analytic slogdet for 2x2 matrices. No LU/solve; avoids TRSV on ROCm.""" - dtype = lax.dtype(a) a00, a01 = a[..., 0, 0], a[..., 0, 1] a10, a11 = a[..., 1, 0], a[..., 1, 1] det = (a00 * a11) - (a01 * a10) - abs_det = ufuncs.abs(det) - is_zero = abs_det == jnp.array(0, dtype=dtype) - if jnp.iscomplexobj(a): - sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), det / ufuncs.astype(abs_det, det.dtype)) - else: - sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), ufuncs.sign(det)) - logabsdet = jnp.where( - is_zero, jnp.array(-np.inf, dtype=dtype), ufuncs.log(abs_det).astype(dtype)) - return sign, ufuncs.real(logabsdet) + return _slogdet_from_det(det, a) @custom_jvp