diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index cd000862bacc..976c509a7f54 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -465,6 +465,67 @@ def matrix_rank( return reductions.sum(S > rtol, axis=-1) +def _slogdet_from_det(det: Array, a: Array) -> tuple[Array, Array]: + """Helper to compute slogdet from a pre-computed determinant.""" + dtype = lax.dtype(a) + abs_det = ufuncs.abs(det) + is_zero = abs_det == jnp.array(0, dtype=dtype) + if jnp.iscomplexobj(a): + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), det / ufuncs.astype(abs_det, det.dtype)) + else: + sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), ufuncs.sign(det)) + logabsdet = jnp.where( + is_zero, jnp.array(-np.inf, dtype=dtype), ufuncs.log(abs_det).astype(dtype)) + return sign, ufuncs.real(logabsdet) + + +def _slogdet_1x1(a: Array) -> tuple[Array, Array]: + """Analytic slogdet for 1x1 matrices. No LU/solve; avoids TRSV on ROCm.""" + det = a[..., 0, 0] + return _slogdet_from_det(det, a) + + +def _slogdet_2x2(a: Array) -> tuple[Array, Array]: + """Analytic slogdet for 2x2 matrices. No LU/solve; avoids TRSV on ROCm.""" + a00, a01 = a[..., 0, 0], a[..., 0, 1] + a10, a11 = a[..., 1, 0], a[..., 1, 1] + det = (a00 * a11) - (a01 * a10) + return _slogdet_from_det(det, a) + + +@custom_jvp +def _slogdet_small(a: Array) -> tuple[Array, Array]: + """slogdet for n in {1, 2} using analytic formulas only (no solve).""" + n = a.shape[-1] + if n == 1: + return _slogdet_1x1(a) + elif n == 2: + return _slogdet_2x2(a) + else: + raise ValueError(f"_slogdet_small only supports n in {{1, 2}}, got n={n}") + + +def _slogdet_small_jvp(primals, tangents): + x, = primals + g, = tangents + n = x.shape[-1] + sign, ans = _slogdet_small(x) + if n == 1: + ans_dot = ufuncs.real(g[..., 0, 0] / x[..., 0, 0]) + else: # n == 2 + x00, x01 = x[..., 0, 0], x[..., 0, 1] + x10, x11 = x[..., 1, 0], x[..., 1, 1] + det = (x00 * x11) - (x01 * x10) + ans_dot = ufuncs.real( + (g[..., 0, 0] * x11 - g[..., 0, 1] * x10 - g[..., 1, 0] * x01 + g[..., 1, 1] * x00) / det + ) + sign_dot = array_creation.zeros_like(sign) + return (sign, ans), (sign_dot, ans_dot) + + +_slogdet_small.defjvp(_slogdet_small_jvp) + + @custom_jvp def _slogdet_lu(a: Array) -> tuple[Array, Array]: dtype = lax.dtype(a) @@ -548,6 +609,9 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": + n = a_shape[-1] + if n == 1 or n == 2: + return SlogdetResult(*_slogdet_small(a)) return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": return SlogdetResult(*_slogdet_qr(a))