Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,67 @@ def matrix_rank(
return reductions.sum(S > rtol, axis=-1)


def _slogdet_from_det(det: Array, a: Array) -> tuple[Array, Array]:
"""Helper to compute slogdet from a pre-computed determinant."""
dtype = lax.dtype(a)
abs_det = ufuncs.abs(det)
is_zero = abs_det == jnp.array(0, dtype=dtype)
if jnp.iscomplexobj(a):
sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), det / ufuncs.astype(abs_det, det.dtype))
else:
sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), ufuncs.sign(det))
logabsdet = jnp.where(
is_zero, jnp.array(-np.inf, dtype=dtype), ufuncs.log(abs_det).astype(dtype))
return sign, ufuncs.real(logabsdet)


def _slogdet_1x1(a: Array) -> tuple[Array, Array]:
"""Analytic slogdet for 1x1 matrices. No LU/solve; avoids TRSV on ROCm."""
det = a[..., 0, 0]
return _slogdet_from_det(det, a)


def _slogdet_2x2(a: Array) -> tuple[Array, Array]:
"""Analytic slogdet for 2x2 matrices. No LU/solve; avoids TRSV on ROCm."""
a00, a01 = a[..., 0, 0], a[..., 0, 1]
a10, a11 = a[..., 1, 0], a[..., 1, 1]
det = (a00 * a11) - (a01 * a10)
return _slogdet_from_det(det, a)


@custom_jvp
def _slogdet_small(a: Array) -> tuple[Array, Array]:
"""slogdet for n in {1, 2} using analytic formulas only (no solve)."""
n = a.shape[-1]
if n == 1:
return _slogdet_1x1(a)
elif n == 2:
return _slogdet_2x2(a)
else:
raise ValueError(f"_slogdet_small only supports n in {{1, 2}}, got n={n}")


def _slogdet_small_jvp(primals, tangents):
x, = primals
g, = tangents
n = x.shape[-1]
sign, ans = _slogdet_small(x)
if n == 1:
ans_dot = ufuncs.real(g[..., 0, 0] / x[..., 0, 0])
else: # n == 2
x00, x01 = x[..., 0, 0], x[..., 0, 1]
x10, x11 = x[..., 1, 0], x[..., 1, 1]
det = (x00 * x11) - (x01 * x10)
ans_dot = ufuncs.real(
(g[..., 0, 0] * x11 - g[..., 0, 1] * x10 - g[..., 1, 0] * x01 + g[..., 1, 1] * x00) / det
)
sign_dot = array_creation.zeros_like(sign)
return (sign, ans), (sign_dot, ans_dot)


_slogdet_small.defjvp(_slogdet_small_jvp)


@custom_jvp
def _slogdet_lu(a: Array) -> tuple[Array, Array]:
dtype = lax.dtype(a)
Expand Down Expand Up @@ -548,6 +609,9 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}")
if method is None or method == "lu":
n = a_shape[-1]
if n == 1 or n == 2:
return SlogdetResult(*_slogdet_small(a))
return SlogdetResult(*_slogdet_lu(a))
elif method == "qr":
return SlogdetResult(*_slogdet_qr(a))
Expand Down