Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7f893c2
Port iternorm
connoraird Jan 19, 2026
498f995
Port discretized_cls
connoraird Jan 19, 2026
e931b16
Port cov_from_spectra
connoraird Jan 19, 2026
962137e
Port lognormal_gls
connoraird Jan 19, 2026
ba5714e
Add jax support to effective_cls
connoraird Jan 19, 2026
287e19b
Port generate and _generate_grf
connoraird Jan 19, 2026
7930346
Test enumerate_spectra for all array backends
connoraird Jan 19, 2026
f4a9c53
Test check_posdef_spectra for all array backends
connoraird Jan 19, 2026
9a7549f
Test regularized_spectra with all array backends
connoraird Jan 20, 2026
3078c7c
Correct empty array test and fix pw type
connoraird Jan 22, 2026
59c448f
Simplify iternorm port to be just convert on input and output
connoraird Jan 27, 2026
9699672
Port _generate_grf
connoraird Jan 27, 2026
15965f3
Ensure spectra_indices is array api compatible
connoraird Jan 27, 2026
d5b1ffa
Ensure test_glass_to_healpix_alm uses complex arrays
connoraird Jan 27, 2026
c2d3a93
Merge branch 'main' into connor/issue-977
connoraird Jan 27, 2026
27a06c4
Fix _generate_grf by setting value of z
connoraird Jan 28, 2026
93b3ac8
Remove redundant conversion to array backend
connoraird Jan 28, 2026
d47a67c
Add tril_indices to uxpx
connoraird Jan 28, 2026
a5a0467
Add FloatArray as a pixwin return type
connoraird Jan 28, 2026
d79234d
Add clarifying comment about porting solve via conversion on inut and…
connoraird Jan 28, 2026
cdeb34e
fix mypy error
connoraird Jan 28, 2026
b35b51a
Merge branch 'main' into connor/issue-977
paddyroddy Jan 28, 2026
cc9d2e0
Extract xp when needed for clarity
connoraird Jan 28, 2026
b5a14d3
Return the correct array type from tril_indices
connoraird Jan 28, 2026
3dd2a45
Add comment to explain why np.asarray is needed
connoraird Jan 28, 2026
d4fb652
Add newline in docstring
connoraird Jan 28, 2026
eb6ca31
Reorder comments
connoraird Jan 28, 2026
4eaeb91
Merge branch 'main' into connor/issue-977
connoraird Jan 29, 2026
690182c
gh-1000: Remove `conftest` imports and consistent `pytest.skip` (#1001)
paddyroddy Jan 29, 2026
cd2f310
Merge remote-tracking branch 'origin/main' into connor/issue-977
connoraird Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion glass/_array_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from glass._types import AnyArray, DTypeLike
from glass._types import AnyArray, DTypeLike, IntArray


class CompatibleBackendNotFoundError(Exception):
Expand Down Expand Up @@ -576,3 +576,42 @@ def ndindex(shape: tuple[int, ...], *, xp: ModuleType) -> np.ndindex:

msg = "the array backend in not supported"
raise NotImplementedError(msg)

@staticmethod
def tril_indices(
n: int,
*,
k: int = 0,
m: int | None = None,
xp: ModuleType,
) -> tuple[IntArray, ...]:
"""
Return the indices for the lower-triangle of an (n, m) array.

Parameters
----------
n
The row dimension of the arrays for which the returned indices will be
valid.
k
Diagonal offset.
m
The column dimension of the arrays for which the returned arrays will be
valid. By default m is taken equal to n.

Returns
-------
The row and column indices, respectively. The row indices are sorted in
non-decreasing order, and the corresponding column indices are strictly
increasing for each row.

"""
if xp.__name__ in {"numpy", "jax.numpy"}:
return xp.tril_indices(n, k=k, m=m) # type: ignore[no-any-return]

if xp.__name__ == "array_api_strict":
np = import_numpy(xp.__name__)
return tuple(xp.asarray(arr) for arr in np.tril_indices(n, k=k, m=m))

msg = "the array backend in not supported"
raise NotImplementedError(msg)
122 changes: 65 additions & 57 deletions glass/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import itertools
import math
import warnings
from collections.abc import Sequence
Expand Down Expand Up @@ -126,66 +125,60 @@ def iternorm(
If the covariance matrix is not positive definite.

"""
# Convert to list here to allow determining the namespace
first = next(cov) # type: ignore[call-overload]
xp = first.__array_namespace__()

n = (size,) if isinstance(size, int) else size

m = xp.zeros((*n, k, k))
a = xp.zeros((*n, k))
s = xp.zeros((*n,))
m = np.zeros((*n, k, k))
a = np.zeros((*n, k))
s = np.zeros((*n,))
q = (*n, k + 1)
j = 0 if k > 0 else None

# We must use cov_expanded here as cov has been consumed to determine the namespace
for i, x in enumerate(itertools.chain([first], cov)):
# Ideally would be xp.asanyarray but this does not yet exist. The key difference
# between the two in numpy is that asanyarray maintains subclasses of NDArray
# whereas asarray will return the base class NDArray. Currently, we don't seem
# to pass a subclass of NDArray so this, so it might be okay
x = xp.asarray(x) # noqa: PLW2901
if x.shape != q:
for i, x in enumerate(cov):
x_np = np.asarray(x)
if x_np.shape != q:
try:
x = xp.broadcast_to(x, q) # noqa: PLW2901
x_np = np.broadcast_to(x_np, q)
except ValueError:
msg = f"covariance row {i}: shape {x.shape} cannot be broadcast to {q}"
msg = (
f"covariance row {i}: shape {x_np.shape} cannot be broadcast to {q}"
)
raise TypeError(msg) from None

# only need to update matrix A if there are correlations
if j is not None:
# compute new entries of matrix A
m[..., :, j] = 0
m[..., j : j + 1, :] = xp.matmul(a[..., xp.newaxis, :], m)
m[..., j, j] = xp.where(s != 0, -1, s)
# To ensure we don't divide by zero or nan we use a mask to only divide the
# appropriate values of m and s
m_j = m[..., j, :]
s_broadcast = xp.broadcast_to(s[..., xp.newaxis], m_j.shape)
mask = (m_j != 0) & (s_broadcast != 0) & ~xp.isnan(s_broadcast)
m_j[mask] = xp.divide(m_j[mask], -s_broadcast[mask])
m[..., j, :] = m_j
m[..., j : j + 1, :] = np.matmul(a[..., np.newaxis, :], m)
m[..., j, j] = np.where(s != 0, -1, 0)
np.divide(
m[..., j, :],
-s[..., np.newaxis],
where=(m[..., j, :] != 0),
out=m[..., j, :],
)

# compute new vector a
c = x[..., 1:, xp.newaxis]
a = xp.matmul(m[..., :j], c[..., k - j :, :])
a += xp.matmul(m[..., j:], c[..., : k - j, :])
a = xp.reshape(a, (*n, k))
c = x_np[..., 1:, np.newaxis]
a = np.matmul(m[..., :j], c[..., k - j :, :])
a += np.matmul(m[..., j:], c[..., : k - j, :])
a = np.reshape(a, (*n, k))

# next rolling index
j = (j - 1) % k

# compute new standard deviation
a_np = np.asarray(a, copy=True)
einsum_result_np = np.einsum("...i,...i", a_np, a_np)
s = x[..., 0] - xp.asarray(einsum_result_np, copy=True)
if xp.any(s < 0):
s = x_np[..., 0] - np.einsum("...i,...i", a, a)
if np.any(s < 0):
msg = "covariance matrix is not positive definite"
raise ValueError(msg)
s = xp.sqrt(s)
s = np.sqrt(s)

# Extract input array backend or conversion of outputs
xp = x.__array_namespace__()

# yield the next index, vector a, and standard deviation s
yield j, a, s
yield j, xp.asarray(a), xp.asarray(s)


def cls2cov(
Expand Down Expand Up @@ -272,25 +265,30 @@ def discretized_cls(
If the length of the Cls array is not a triangle number.

"""
if len(cls) == 0:
return []

xp = array_api_compat.array_namespace(*cls, use_compat=False)

if ncorr is not None:
n = nfields_from_nspectra(len(cls))
cls = [
cls[i * (i + 1) // 2 + j] if j <= ncorr else np.asarray([])
cls[i * (i + 1) // 2 + j] if j <= ncorr else xp.asarray([])
for i in range(n)
for j in range(i + 1)
]

if nside is not None:
pw = hp.pixwin(nside, lmax=lmax, xp=np)
pw: FloatArray = hp.pixwin(nside, lmax=lmax, xp=xp)

gls = []
for cl in cls:
if len(cl) > 0: # type: ignore[arg-type]
if cl.shape[0] > 0:
if lmax is not None:
cl = cl[: lmax + 1] # noqa: PLW2901
if nside is not None:
n = min(len(cl), len(pw)) # type: ignore[arg-type]
cl = cl[:n] * pw[:n] ** 2 # type: ignore[operator] # noqa: PLW2901
n = min(cl.shape[0], pw.shape[0])
cl = cl[:n] * pw[:n] ** 2 # noqa: PLW2901
gls.append(cl)
return gls

Expand Down Expand Up @@ -369,8 +367,10 @@ def _generate_grf(
If all gls are empty.

"""
xp = array_api_compat.array_namespace(*gls, use_compat=False)

if rng is None:
rng = _rng.rng_dispatcher(xp=np)
rng = _rng.rng_dispatcher(xp=xp)

# number of gls and number of fields
ngls = len(gls)
Expand All @@ -381,7 +381,7 @@ def _generate_grf(
ncorr = ngrf - 1

# number of modes
n = max((len(gl) for gl in gls), default=0) # type: ignore[arg-type]
n = max((gl.shape[0] for gl in gls), default=0)
if n == 0:
msg = "all gls are empty"
raise ValueError(msg)
Expand All @@ -390,8 +390,8 @@ def _generate_grf(
cov = cls2cov(gls, n, ngrf, ncorr)

# working arrays for the iterative sampling
z = np.zeros(n * (n + 1) // 2, dtype=np.complex128)
y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128)
z_size = n * (n + 1) // 2
y = xp.zeros((z_size, ncorr), dtype=xp.complex128)

# generate the conditional normal distribution for iterative sampling
conditional_dist = iternorm(ncorr, cov, size=n)
Expand All @@ -400,7 +400,7 @@ def _generate_grf(
for j, a, s in conditional_dist:
# standard normal random variates for alm
# sample real and imaginary parts, then view as complex number
rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) # type: ignore[call-arg]
z = rng.standard_normal((z_size,)) + (1j * rng.standard_normal((z_size,)))

# scale by standard deviation of the conditional distribution
# variance is distributed over real and imaginary part
Expand All @@ -412,13 +412,12 @@ def _generate_grf(

# store the standard normal in y array at the indicated index
if j is not None:
y[:, j] = z
y = xpx.at(y)[:, j].set(z)

alm = _glass_to_healpix_alm(alm)

# modes with m = 0 are real-valued and come first in array
np.real(alm[:n])[:] += np.imag(alm[:n])
np.imag(alm[:n])[:] = 0
alm = xpx.at(alm)[:n].set(xp.real(alm[:n]) + xp.imag(alm[:n]) + 0j)

# transform alm to maps
# can be performed in place on the temporary alm array
Expand Down Expand Up @@ -596,8 +595,8 @@ def spectra_indices(n: int, *, xp: ModuleType | None = None) -> IntArray:
"""
xp = _utils.default_xp() if xp is None else xp

i, j = xp.tril_indices(n)
return xp.asarray([i, i - j]).T
i, j = uxpx.tril_indices(n, xp=xp)
return xp.stack([i, i - j]).T


def effective_cls(
Expand Down Expand Up @@ -681,9 +680,9 @@ def effective_cls(
for i1 in range(n)
for i2 in range(n)
)
out[j1 + j2 + (...,)] = cl
out = xpx.at(out)[j1 + j2 + (...,)].set(cl)
if weights2 is weights1 and j1 != j2:
out[j2 + j1 + (...,)] = cl
out = xpx.at(out)[j2 + j1 + (...,)].set(cl)
return out


Expand Down Expand Up @@ -796,6 +795,9 @@ def solve_gaussian_spectra(
msg = "mismatch between number of fields and spectra"
raise ValueError(msg)

if len(spectra) == 0:
return []

gls = []
for i, j, cl in enumerate_spectra(spectra):
if cl.size > 0:
Expand Down Expand Up @@ -832,7 +834,7 @@ def generate(
nside: int,
*,
ncorr: int | None = None,
rng: np.random.Generator | None = None,
rng: UnifiedGenerator | None = None,
) -> Iterator[AnyArray]:
"""
Sample random fields from Gaussian angular power spectra.
Expand Down Expand Up @@ -876,7 +878,8 @@ def generate(
msg = "mismatch between number of fields and gls"
raise ValueError(msg)

variances = (cltovar(getcl(gls, i, i)) for i in range(n))
# cltovar requires numpy but getcl maintains xp, so conversion is required
variances = (cltovar(np.asarray(getcl(gls, i, i))) for i in range(n))
grf = _generate_grf(gls, nside, ncorr=ncorr, rng=rng)

for t, x, var in zip(fields, grf, variances, strict=True):
Expand Down Expand Up @@ -995,6 +998,8 @@ def cov_from_spectra(
Covariance matrix from the given spectra.

"""
xp = array_api_compat.array_namespace(*spectra, use_compat=False)

# recover the number of fields from the number of spectra
n = nfields_from_nspectra(len(spectra))

Expand All @@ -1004,14 +1009,17 @@ def cov_from_spectra(
# this is the covariance matrix of the spectra
# the leading dimension is k, then it is a n-by-n covariance matrix
# missing entries are zero, which is the default value
cov = np.zeros((k, n, n))
cov = xp.zeros((k, n, n), dtype=spectra[0].dtype)

# fill the matrix up by going through the spectra in order
# skip over entries that are None
# if the spectra are ragged, some entries at high ell may remain zero
# only fill the lower triangular part, everything is symmetric
for i, j, cl in enumerate_spectra(spectra):
cov[: cl.size, i, j] = cov[: cl.size, j, i] = cl.reshape(-1)[:k] # type: ignore[union-attr]
size = min(k, cl.size)
cl_flat = xp.reshape(cl, (-1,))
cov = xpx.at(cov)[:size, i, j].set(cl_flat[:size])
cov = xpx.at(cov)[:size, j, i].set(cl_flat[:size])

return cov

Expand Down
10 changes: 8 additions & 2 deletions glass/grf/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,16 @@ def solve( # noqa: PLR0912, PLR0913
:func:`glass.grf.compute`: Direct computation for band-limited spectra.

"""
xp = cl.__array_namespace__()

# This function is difficult to port to the Array API so for now we work
# in NumPy and ultimately convert back at the end of it.
cl = np.asarray(cl)

if t2 is None:
t2 = t1

n = len(cl) # type: ignore[arg-type]
n = len(cl)
if pad < 0:
msg = "pad must be a positive integer"
raise ValueError(msg)
Expand Down Expand Up @@ -138,4 +144,4 @@ def solve( # noqa: PLR0912, PLR0913

gl, gt, rl, fl, clerr = gl_, gt_, rl_, fl_, clerr_

return gl, rl, info
return xp.asarray(gl), xp.asarray(rl), info
2 changes: 1 addition & 1 deletion glass/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def pixwin(
lmax: int | None = None,
pol: bool = False,
xp: ModuleType | None = None,
) -> tuple[FloatArray, ...]:
) -> FloatArray | tuple[FloatArray, ...]:
"""
Return the pixel window function for the given nside.

Expand Down
Loading
Loading