From 7f893c2c7fdded2ba1a34a4167ebdc3253633367 Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 14:26:59 +0000 Subject: [PATCH 01/26] Port iternorm --- glass/fields.py | 6 +++--- tests/core/test_fields.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index df71f4018..9f9709fee 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -126,6 +126,7 @@ def iternorm( # Convert to list here to allow determining the namespace first = next(cov) # type: ignore[call-overload] xp = first.__array_namespace__() + uxpx = _utils.XPAdditions(xp) n = (size,) if isinstance(size, int) else size @@ -173,9 +174,8 @@ def iternorm( j = (j - 1) % k # compute new standard deviation - a_np = np.asarray(a, copy=True) - einsum_result_np = np.einsum("...i,...i", a_np, a_np) - s = x[..., 0] - xp.asarray(einsum_result_np, copy=True) + einsum_result_np = uxpx.einsum("...i,...i", a, a) + s = x[..., 0] - einsum_result_np if xp.any(s < 0): msg = "covariance matrix is not positive definite" raise ValueError(msg) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 55eb9dac2..b40fe407b 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -28,10 +28,6 @@ def not_triangle_numbers() -> list[int]: def test_iternorm(xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written - if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in iternorm are not immutable, so do not support jax") - # check output shapes and types k = 2 From 498f9957d53df4f8b149f4a2716f3c7f7d1a1f47 Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 14:28:56 +0000 Subject: [PATCH 02/26] Port discretized_cls --- glass/fields.py | 15 ++++++++++----- tests/core/test_fields.py | 30 ++++++++++++++++++------------ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 9f9709fee..be6f76534 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -269,25 +269,30 @@ def discretized_cls( If the length of the Cls array is not a triangle number. """ + if len(cls) == 0: # type: ignore[arg-type] + return [] + + xp = array_api_compat.array_namespace(*cls, use_compat=False) + if ncorr is not None: n = nfields_from_nspectra(len(cls)) cls = [ - cls[i * (i + 1) // 2 + j] if j <= ncorr else np.asarray([]) + cls[i * (i + 1) // 2 + j] if j <= ncorr else xp.asarray([]) for i in range(n) for j in range(i + 1) ] if nside is not None: - pw = hp.pixwin(nside, lmax=lmax, xp=np) + pw = hp.pixwin(nside, lmax=lmax, xp=xp) gls = [] for cl in cls: - if len(cl) > 0: # type: ignore[arg-type] + if cl.shape[0] > 0: if lmax is not None: cl = cl[: lmax + 1] # noqa: PLW2901 if nside is not None: - n = min(len(cl), len(pw)) # type: ignore[arg-type] - cl = cl[:n] * pw[:n] ** 2 # type: ignore[operator] # noqa: PLW2901 + n = min(cl.shape[0], pw.shape[0]) + cl = cl[:n] * pw[:n] ** 2 # noqa: PLW2901 gls.append(cl) return gls diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index b40fe407b..9c484132e 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -294,7 +294,10 @@ def test_lognormal_gls() -> None: assert len(out[2]) == 3 -def test_discretized_cls(compare: type[Compare]) -> None: +def test_discretized_cls( + compare: type[Compare], + xp: ModuleType, +) -> None: # empty cls result = glass.discretized_cls([]) @@ -303,39 +306,42 @@ def test_discretized_cls(compare: type[Compare]) -> None: # power spectra truncated at lmax + 1 if lmax provided result = glass.discretized_cls( - [np.arange(10), np.arange(10), np.arange(10)], + [xp.arange(10), xp.arange(10), xp.arange(10)], lmax=5, ) for cl in result: - assert len(cl) == 6 + assert cl.shape[0] == 6 # check ValueError for triangle number with pytest.raises(ValueError, match="invalid number of spectra:"): - glass.discretized_cls([np.arange(10), np.arange(10)], ncorr=1) + glass.discretized_cls([xp.arange(10), xp.arange(10)], ncorr=1) # ncorr not None - cls: AngularPowerSpectra = [np.arange(10), np.arange(10), np.arange(10)] + cls: AngularPowerSpectra = [xp.arange(10), xp.arange(10), xp.arange(10)] ncorr = 0 result = glass.discretized_cls(cls, ncorr=ncorr) - assert len(result[0]) == 10 - assert len(result[1]) == 10 - assert len(result[2]) == 0 # third correlation should be removed + assert result[0].shape[0] == 10 + assert result[1].shape[0] == 10 + assert result[2].shape[0] == 0 # third correlation should be removed # check if pixel window function was applied correctly with nside not None nside = 4 - pw = hp.pixwin(nside, lmax=7, xp=np) + pw = hp.pixwin(nside, lmax=7, xp=xp) - result = glass.discretized_cls([[], np.ones(10), np.ones(10)], nside=nside) + result = glass.discretized_cls( + [xp.asarray([]), xp.ones(10), xp.ones(10)], + nside=nside, + ) for cl in result: - n = min(len(cl), len(pw)) - expected = np.ones(n) * pw[:n] ** 2 # type: ignore[operator] + n = min(cl.shape[0], pw.shape[0]) + expected = xp.ones(n) * pw[:n] ** 2 # type: ignore[operator] compare.assert_allclose(cl[:n], expected) From e931b16f5078764cb03f5e0bbfe21254dae3c314 Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 15:01:00 +0000 Subject: [PATCH 03/26] Port cov_from_spectra --- glass/fields.py | 9 +++++++-- tests/core/test_fields.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index be6f76534..b3ced16a0 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -999,6 +999,8 @@ def cov_from_spectra( Covariance matrix from the given spectra. """ + xp = array_api_compat.array_namespace(*spectra, use_compat=False) + # recover the number of fields from the number of spectra n = nfields_from_nspectra(len(spectra)) @@ -1008,14 +1010,17 @@ def cov_from_spectra( # this is the covariance matrix of the spectra # the leading dimension is k, then it is a n-by-n covariance matrix # missing entries are zero, which is the default value - cov = np.zeros((k, n, n)) + cov = xp.zeros((k, n, n), dtype=spectra[0].dtype) # fill the matrix up by going through the spectra in order # skip over entries that are None # if the spectra are ragged, some entries at high ell may remain zero # only fill the lower triangular part, everything is symmetric for i, j, cl in enumerate_spectra(spectra): - cov[: cl.size, i, j] = cov[: cl.size, j, i] = cl.reshape(-1)[:k] # type: ignore[union-attr] + size = min(k, cl.size) + cl_flat = xp.reshape(cl, (-1,)) + cov = xpx.at(cov)[:size, i, j].set(cl_flat[:size]) + cov = xpx.at(cov)[:size, j, i].set(cl_flat[:size]) return cov diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 9c484132e..7143c95fe 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -673,17 +673,21 @@ def test_lognormal_shift_hilbert2011(compare: type[Compare]) -> None: compare.assert_allclose(shifts, check, atol=1e-4, rtol=1e-4) -def test_cov_from_spectra(compare: type[Compare]) -> None: - spectra: AngularPowerSpectra = np.asarray( - [ +def test_cov_from_spectra( + compare: type[Compare], + xp: ModuleType, +) -> None: + spectra: AngularPowerSpectra = [ + xp.asarray(x) + for x in [ [110, 111, 112, 113], [220, 221, 222, 223], [210, 211, 212, 213], [330, 331, 332, 333], [320, 321, 322, 323], [310, 311, 312, 313], - ], - ) + ] + ] compare.assert_array_equal( glass.cov_from_spectra(spectra), From 962137e973dbac517fd129cce67bdce7356eb7fb Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 17:00:32 +0000 Subject: [PATCH 04/26] Port lognormal_gls --- glass/fields.py | 3 +++ glass/grf/_solver.py | 7 +++++-- tests/core/test_fields.py | 15 ++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index b3ced16a0..d58f4c80e 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -800,6 +800,9 @@ def solve_gaussian_spectra( msg = "mismatch between number of fields and spectra" raise ValueError(msg) + if len(spectra) == 0: # type: ignore[arg-type] + return [] + gls = [] for i, j, cl in enumerate_spectra(spectra): if cl.size > 0: diff --git a/glass/grf/_solver.py b/glass/grf/_solver.py index 8533b6d43..a25123c44 100644 --- a/glass/grf/_solver.py +++ b/glass/grf/_solver.py @@ -80,10 +80,13 @@ def solve( # noqa: PLR0912, PLR0913 :func:`glass.grf.compute`: Direct computation for band-limited spectra. """ + xp = cl.__array_namespace__() + cl = np.asarray(cl) + if t2 is None: t2 = t1 - n = len(cl) # type: ignore[arg-type] + n = len(cl) if pad < 0: msg = "pad must be a positive integer" raise ValueError(msg) @@ -138,4 +141,4 @@ def solve( # noqa: PLR0912, PLR0913 gl, gt, rl, fl, clerr = gl_, gt_, rl_, fl_, clerr_ - return gl, rl, info + return xp.asarray(gl), xp.asarray(rl), info diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 7143c95fe..a596202a9 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -273,7 +273,7 @@ def test_cls2cov_no_jax(compare: type[Compare], xpb: ModuleType) -> None: compare.assert_allclose(cov2_copy, cov3) -def test_lognormal_gls() -> None: +def test_lognormal_gls(xp: ModuleType) -> None: shift = 2 # empty cls @@ -282,16 +282,17 @@ def test_lognormal_gls() -> None: # check output shape - assert len(glass.lognormal_gls([np.linspace(1, 5, 5)], shift)) == 1 - assert len(glass.lognormal_gls([np.linspace(1, 5, 5)], shift)[0]) == 5 + out = glass.lognormal_gls([xp.linspace(1, 5, 5)], shift) + assert len(out) == 1 + assert out[0].shape[0] == 5 - inp = [np.linspace(1, 6, 5), np.linspace(1, 5, 4), np.linspace(1, 4, 3)] + inp = [xp.linspace(1, 6, 5), xp.linspace(1, 5, 4), xp.linspace(1, 4, 3)] out = glass.lognormal_gls(inp, shift) assert len(out) == 3 - assert len(out[0]) == 5 - assert len(out[1]) == 4 - assert len(out[2]) == 3 + assert out[0].shape[0] == 5 + assert out[1].shape[0] == 4 + assert out[2].shape[0] == 3 def test_discretized_cls( From ba5714e400f0f23f041a4d723dc05b1b3866b366 Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 17:02:54 +0000 Subject: [PATCH 05/26] Add jax support to effective_cls --- glass/fields.py | 4 ++-- tests/core/test_fields.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index d58f4c80e..0f29a867e 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -685,9 +685,9 @@ def effective_cls( for i1 in range(n) for i2 in range(n) ) - out[j1 + j2 + (...,)] = cl + out = xpx.at(out)[j1 + j2 + (...,)].set(cl) if weights2 is weights1 and j1 != j2: - out[j2 + j1 + (...,)] = cl + out = xpx.at(out)[j2 + j1 + (...,)].set(cl) return out diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index a596202a9..23b532856 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -347,10 +347,6 @@ def test_discretized_cls( def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written - if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in effective_cls are not immutable, so do not support jax") - # empty cls result = glass.effective_cls([], xp.asarray([])) From 287e19b497564111f58ef7a45c5931760623db9f Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 17:15:18 +0000 Subject: [PATCH 06/26] Port generate and _generate_grf --- glass/fields.py | 4 +++- tests/core/test_fields.py | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 0f29a867e..9d4a36aad 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -878,12 +878,14 @@ def generate( Sampled random fields. """ + xp = array_api_compat.array_namespace(*gls, use_compat=False) + n = len(fields) if len(gls) != n * (n + 1) // 2: msg = "mismatch between number of fields and gls" raise ValueError(msg) - variances = (cltovar(getcl(gls, i, i)) for i in range(n)) + variances = (xp.asarray(cltovar(np.asarray(getcl(gls, i, i)))) for i in range(n)) grf = _generate_grf(gls, nside, ncorr=ncorr, rng=rng) for t, x, var in zip(fields, grf, variances, strict=True): diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 23b532856..54b2f1f01 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -422,19 +422,22 @@ def test_generate_lognormal(xp: ModuleType) -> None: glass.generate_lognormal([xp.asarray([1.0, 0.5, 0.1])], 4) -def test_generate(compare: type[Compare]) -> None: +def test_generate( + compare: type[Compare], + xp: ModuleType, +) -> None: # shape mismatch error fields = [lambda x, var: x, lambda x, var: x] # noqa: ARG005 with pytest.raises(ValueError, match="mismatch between number of fields and gls"): - list(glass.generate(fields, [np.ones(10), np.ones(10)], nside=16)) + list(glass.generate(fields, [xp.ones(10), xp.ones(10)], nside=16)) # check output shape nside = 16 npix = hp.nside2npix(nside) - gls: AngularPowerSpectra = [np.ones(10), np.ones(10), np.ones(10)] + gls: AngularPowerSpectra = [xp.ones(10), xp.ones(10), xp.ones(10)] result = list(glass.generate(fields, gls, nside=nside)) From 79303465f108272f9fe109684951273ed6b4f2f3 Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 17:27:20 +0000 Subject: [PATCH 07/26] Test enumerate_spectra for all array backends --- tests/core/test_fields.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 54b2f1f01..3b74370a0 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -463,7 +463,10 @@ def test_generate( compare.assert_allclose(result[1], result[0] ** 2, atol=1e-05) -def test_getcl(compare: type[Compare], xp: ModuleType) -> None: +def test_getcl( + compare: type[Compare], + xp: ModuleType, +) -> None: # make a mock Cls array with the index pairs as entries cls: AngularPowerSpectra = [ xp.asarray([i, j], dtype=xp.float64) @@ -508,12 +511,15 @@ def test_nfields_from_nspectra(not_triangle_numbers: list[int]) -> None: glass.nfields_from_nspectra(t) -def test_enumerate_spectra() -> None: +def test_enumerate_spectra( + compare: type[Compare], + xp: ModuleType, +) -> None: n = 100 tn = n * (n + 1) // 2 # create mock spectra with 1 element counting to tn - spectra: AngularPowerSpectra = np.arange(tn).reshape(tn, 1) + spectra: AngularPowerSpectra = [xp.asarray(x) for x in range(tn)] # this is the expected order of indices indices = [(i, j) for i in range(n) for j in range(i, -1, -1)] @@ -523,7 +529,7 @@ def test_enumerate_spectra() -> None: # go through expected indices and values and compare for k, (i, j) in enumerate(indices): - assert next(it) == (i, j, k) + compare.assert_allclose(next(it), (i, j, k)) # make sure iterator is exhausted with pytest.raises(StopIteration): From f4a9c53e7a1873eedb7ba5f31ce206bbac0b07ca Mon Sep 17 00:00:00 2001 From: connoraird Date: Mon, 19 Jan 2026 17:32:09 +0000 Subject: [PATCH 08/26] Test check_posdef_spectra for all array backends --- tests/core/test_fields.py | 49 ++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 3b74370a0..f0b6f0a1b 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -295,10 +295,7 @@ def test_lognormal_gls(xp: ModuleType) -> None: assert out[2].shape[0] == 3 -def test_discretized_cls( - compare: type[Compare], - xp: ModuleType, -) -> None: +def test_discretized_cls(compare: type[Compare], xp: ModuleType) -> None: # empty cls result = glass.discretized_cls([]) @@ -422,10 +419,7 @@ def test_generate_lognormal(xp: ModuleType) -> None: glass.generate_lognormal([xp.asarray([1.0, 0.5, 0.1])], 4) -def test_generate( - compare: type[Compare], - xp: ModuleType, -) -> None: +def test_generate(compare: type[Compare], xp: ModuleType) -> None: # shape mismatch error fields = [lambda x, var: x, lambda x, var: x] # noqa: ARG005 @@ -463,10 +457,7 @@ def test_generate( compare.assert_allclose(result[1], result[0] ** 2, atol=1e-05) -def test_getcl( - compare: type[Compare], - xp: ModuleType, -) -> None: +def test_getcl(compare: type[Compare], xp: ModuleType) -> None: # make a mock Cls array with the index pairs as entries cls: AngularPowerSpectra = [ xp.asarray([i, j], dtype=xp.float64) @@ -511,10 +502,7 @@ def test_nfields_from_nspectra(not_triangle_numbers: list[int]) -> None: glass.nfields_from_nspectra(t) -def test_enumerate_spectra( - compare: type[Compare], - xp: ModuleType, -) -> None: +def test_enumerate_spectra(compare: type[Compare], xp: ModuleType) -> None: n = 100 tn = n * (n + 1) // 2 @@ -769,36 +757,39 @@ def test_cov_from_spectra( ) -def test_check_posdef_spectra() -> None: +def test_check_posdef_spectra(xp: ModuleType) -> None: # posdef spectra assert glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.9, 0.9, 0.9], - ], - ), + ] + ] ) # semidef spectra assert glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.9, 1.0, 0.0], - ], - ), + ] + ] ) # indef spectra assert not glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.1, 1.1, 1.1], - ], - ), + ] + ] ) From 9a7549f2a06bc0e29f25f4eedfbf6d76948bfcc0 Mon Sep 17 00:00:00 2001 From: connoraird Date: Tue, 20 Jan 2026 09:58:18 +0000 Subject: [PATCH 09/26] Test regularized_spectra with all array backends --- tests/core/test_fields.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index f0b6f0a1b..106e22a1d 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -17,6 +17,7 @@ from pytest_mock import MockerFixture from glass._types import AngularPowerSpectra + from glass._types import UnifiedGenerator from tests.fixtures.helper_classes import Compare HAVE_JAX = importlib.util.find_spec("jax") is not None @@ -795,9 +796,9 @@ def test_check_posdef_spectra(xp: ModuleType) -> None: def test_regularized_spectra( mocker: MockerFixture, - rng: np.random.Generator, + urng: UnifiedGenerator, ) -> None: - spectra: AngularPowerSpectra = rng.random(size=(6, 101)) + spectra: AngularPowerSpectra = [urng.random(101) for _ in range(6)] # test method "nearest" cov_nearest = mocker.spy(glass.algorithm, "cov_nearest") From 3078c7c30ab4d5142dd71d0e71fea4044c851aea Mon Sep 17 00:00:00 2001 From: connoraird Date: Thu, 22 Jan 2026 13:04:51 +0000 Subject: [PATCH 10/26] Correct empty array test and fix pw type --- glass/fields.py | 6 +++--- tests/core/test_fields.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 9d4a36aad..9f55f30e4 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -269,7 +269,7 @@ def discretized_cls( If the length of the Cls array is not a triangle number. """ - if len(cls) == 0: # type: ignore[arg-type] + if len(cls) == 0: return [] xp = array_api_compat.array_namespace(*cls, use_compat=False) @@ -283,7 +283,7 @@ def discretized_cls( ] if nside is not None: - pw = hp.pixwin(nside, lmax=lmax, xp=xp) + pw: FloatArray = hp.pixwin(nside, lmax=lmax, xp=xp) gls = [] for cl in cls: @@ -800,7 +800,7 @@ def solve_gaussian_spectra( msg = "mismatch between number of fields and spectra" raise ValueError(msg) - if len(spectra) == 0: # type: ignore[arg-type] + if len(spectra) == 0: return [] gls = [] diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 106e22a1d..3bf0d9897 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -3,7 +3,6 @@ import importlib.util from typing import TYPE_CHECKING -import numpy as np import pytest import glass @@ -16,8 +15,7 @@ from pytest_mock import MockerFixture - from glass._types import AngularPowerSpectra - from glass._types import UnifiedGenerator + from glass._types import AngularPowerSpectra, FloatArray, UnifiedGenerator from tests.fixtures.helper_classes import Compare HAVE_JAX = importlib.util.find_spec("jax") is not None @@ -331,7 +329,7 @@ def test_discretized_cls(compare: type[Compare], xp: ModuleType) -> None: nside = 4 - pw = hp.pixwin(nside, lmax=7, xp=xp) + pw: FloatArray = hp.pixwin(nside, lmax=7, xp=xp) result = glass.discretized_cls( [xp.asarray([]), xp.ones(10), xp.ones(10)], @@ -340,7 +338,7 @@ def test_discretized_cls(compare: type[Compare], xp: ModuleType) -> None: for cl in result: n = min(cl.shape[0], pw.shape[0]) - expected = xp.ones(n) * pw[:n] ** 2 # type: ignore[operator] + expected = xp.ones(n) * pw[:n] ** 2 compare.assert_allclose(cl[:n], expected) @@ -406,8 +404,11 @@ def test_generate_grf(compare: type[Compare]) -> None: compare.assert_allclose(new_gaussian_fields[0], gaussian_fields[0]) - with pytest.raises(ValueError, match="all gls are empty"): - list(glass.fields._generate_grf([np.asarray([])], nside)) + with pytest.raises( + ValueError, + match="all gls are empty", + ): + list(glass.fields._generate_grf([xp.asarray([])], nside)) def test_generate_gaussian(xp: ModuleType) -> None: From 59c448fc97714eb3fcbc32b0962804f8feaeb01a Mon Sep 17 00:00:00 2001 From: connoraird Date: Tue, 27 Jan 2026 14:14:27 +0000 Subject: [PATCH 11/26] Simplify iternorm port to be just convert on input and output --- glass/fields.py | 49 +++++++++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 9f55f30e4..07d176023 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -126,26 +126,21 @@ def iternorm( # Convert to list here to allow determining the namespace first = next(cov) # type: ignore[call-overload] xp = first.__array_namespace__() - uxpx = _utils.XPAdditions(xp) n = (size,) if isinstance(size, int) else size - m = xp.zeros((*n, k, k)) - a = xp.zeros((*n, k)) - s = xp.zeros((*n,)) + m = np.zeros((*n, k, k)) + a = np.zeros((*n, k)) + s = np.zeros((*n,)) q = (*n, k + 1) j = 0 if k > 0 else None # We must use cov_expanded here as cov has been consumed to determine the namespace for i, x in enumerate(itertools.chain([first], cov)): - # Ideally would be xp.asanyarray but this does not yet exist. The key difference - # between the two in numpy is that asanyarray maintains subclasses of NDArray - # whereas asarray will return the base class NDArray. Currently, we don't seem - # to pass a subclass of NDArray so this, so it might be okay - x = xp.asarray(x) # noqa: PLW2901 + x = np.asarray(x) # noqa: PLW2901 if x.shape != q: try: - x = xp.broadcast_to(x, q) # noqa: PLW2901 + x = np.broadcast_to(x, q) # noqa: PLW2901 except ValueError: msg = f"covariance row {i}: shape {x.shape} cannot be broadcast to {q}" raise TypeError(msg) from None @@ -154,35 +149,33 @@ def iternorm( if j is not None: # compute new entries of matrix A m[..., :, j] = 0 - m[..., j : j + 1, :] = xp.matmul(a[..., xp.newaxis, :], m) - m[..., j, j] = xp.where(s != 0, -1, s) - # To ensure we don't divide by zero or nan we use a mask to only divide the - # appropriate values of m and s - m_j = m[..., j, :] - s_broadcast = xp.broadcast_to(s[..., xp.newaxis], m_j.shape) - mask = (m_j != 0) & (s_broadcast != 0) & ~xp.isnan(s_broadcast) - m_j[mask] = xp.divide(m_j[mask], -s_broadcast[mask]) - m[..., j, :] = m_j + m[..., j : j + 1, :] = np.matmul(a[..., np.newaxis, :], m) + m[..., j, j] = np.where(s != 0, -1, 0) + np.divide( + m[..., j, :], + -s[..., np.newaxis], + where=(m[..., j, :] != 0), + out=m[..., j, :], + ) # compute new vector a - c = x[..., 1:, xp.newaxis] - a = xp.matmul(m[..., :j], c[..., k - j :, :]) - a += xp.matmul(m[..., j:], c[..., : k - j, :]) - a = xp.reshape(a, (*n, k)) + c = x[..., 1:, np.newaxis] + a = np.matmul(m[..., :j], c[..., k - j :, :]) + a += np.matmul(m[..., j:], c[..., : k - j, :]) + a = np.reshape(a, (*n, k)) # next rolling index j = (j - 1) % k # compute new standard deviation - einsum_result_np = uxpx.einsum("...i,...i", a, a) - s = x[..., 0] - einsum_result_np - if xp.any(s < 0): + s = x[..., 0] - np.einsum("...i,...i", a, a) + if np.any(s < 0): msg = "covariance matrix is not positive definite" raise ValueError(msg) - s = xp.sqrt(s) + s = np.sqrt(s) # yield the next index, vector a, and standard deviation s - yield j, a, s + yield j, xp.asarray(a), xp.asarray(s) def cls2cov( From 9699672483f43664c18a1538b17de70fbe98bc82 Mon Sep 17 00:00:00 2001 From: connoraird Date: Tue, 27 Jan 2026 14:26:20 +0000 Subject: [PATCH 12/26] Port _generate_grf --- glass/fields.py | 17 +++++++++-------- tests/core/test_fields.py | 13 +++++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 07d176023..491aaaeea 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -364,8 +364,10 @@ def _generate_grf( If all gls are empty. """ + xp = array_api_compat.array_namespace(*gls, use_compat=False) + if rng is None: - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) # number of gls and number of fields ngls = len(gls) @@ -376,7 +378,7 @@ def _generate_grf( ncorr = ngrf - 1 # number of modes - n = max((len(gl) for gl in gls), default=0) # type: ignore[arg-type] + n = max((gl.shape[0] for gl in gls), default=0) if n == 0: msg = "all gls are empty" raise ValueError(msg) @@ -385,8 +387,8 @@ def _generate_grf( cov = cls2cov(gls, n, ngrf, ncorr) # working arrays for the iterative sampling - z = np.zeros(n * (n + 1) // 2, dtype=np.complex128) - y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128) + z = xp.zeros(n * (n + 1) // 2, dtype=xp.complex128) + y = xp.zeros((n * (n + 1) // 2, ncorr), dtype=xp.complex128) # generate the conditional normal distribution for iterative sampling conditional_dist = iternorm(ncorr, cov, size=n) @@ -395,7 +397,7 @@ def _generate_grf( for j, a, s in conditional_dist: # standard normal random variates for alm # sample real and imaginary parts, then view as complex number - rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) # type: ignore[call-arg] + rng.standard_normal((n * (n + 1),)) + (1j * rng.standard_normal((n * (n + 1),))) # scale by standard deviation of the conditional distribution # variance is distributed over real and imaginary part @@ -407,13 +409,12 @@ def _generate_grf( # store the standard normal in y array at the indicated index if j is not None: - y[:, j] = z + y = xpx.at(y)[:, j].set(z) alm = _glass_to_healpix_alm(alm) # modes with m = 0 are real-valued and come first in array - np.real(alm[:n])[:] += np.imag(alm[:n]) - np.imag(alm[:n])[:] = 0 + alm = xpx.at(alm)[:n].set(xp.real(alm[:n])[:] + xp.imag(alm[:n]) + 0j) # transform alm to maps # can be performed in place on the temporary alm array diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 3bf0d9897..a0ed2ce84 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -379,8 +379,8 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: assert result.shape == (1, 1, 15) -def test_generate_grf(compare: type[Compare]) -> None: - gls: AngularPowerSpectra = [np.asarray([1.0, 0.5, 0.1])] +def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: + gls = [xp.asarray([1.0, 0.5, 0.1])] nside = 4 ncorr = 1 @@ -389,13 +389,13 @@ def test_generate_grf(compare: type[Compare]) -> None: assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) gaussian_fields = list(glass.fields._generate_grf(gls, nside, rng=rng)) assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) new_gaussian_fields = list( glass.fields._generate_grf(gls, nside, ncorr=ncorr, rng=rng), ) @@ -404,10 +404,7 @@ def test_generate_grf(compare: type[Compare]) -> None: compare.assert_allclose(new_gaussian_fields[0], gaussian_fields[0]) - with pytest.raises( - ValueError, - match="all gls are empty", - ): + with pytest.raises(ValueError, match="all gls are empty"): list(glass.fields._generate_grf([xp.asarray([])], nside)) From 15965f3aa3a166e3572dad56eb779aed5cb3e0dc Mon Sep 17 00:00:00 2001 From: connoraird Date: Tue, 27 Jan 2026 14:39:20 +0000 Subject: [PATCH 13/26] Ensure spectra_indices is array api compatible --- glass/fields.py | 2 +- tests/core/test_fields.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 491aaaeea..4fc59577a 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -592,7 +592,7 @@ def spectra_indices(n: int, *, xp: ModuleType | None = None) -> IntArray: """ xp = _utils.default_xp() if xp is None else xp - i, j = xp.tril_indices(n) + i, j = np.tril_indices(n) return xp.asarray([i, i - j]).T diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index a0ed2ce84..635b85b88 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -525,10 +525,13 @@ def test_enumerate_spectra(compare: type[Compare], xp: ModuleType) -> None: def test_spectra_indices(compare: type[Compare], xp: ModuleType) -> None: compare.assert_array_equal(glass.spectra_indices(0), xp.zeros((0, 2))) - compare.assert_array_equal(glass.spectra_indices(1), [[0, 0]]) - compare.assert_array_equal(glass.spectra_indices(2), [[0, 0], [1, 1], [1, 0]]) + compare.assert_array_equal(glass.spectra_indices(0, xp=xp), xp.zeros((0, 2))) + compare.assert_array_equal(glass.spectra_indices(1, xp=xp), [[0, 0]]) compare.assert_array_equal( - glass.spectra_indices(3), + glass.spectra_indices(2, xp=xp), [[0, 0], [1, 1], [1, 0]] + ) + compare.assert_array_equal( + glass.spectra_indices(3, xp=xp), [[0, 0], [1, 1], [1, 0], [2, 2], [2, 1], [2, 0]], ) From d5b1ffa128eee361574e3cb0daa9339209521cbf Mon Sep 17 00:00:00 2001 From: connoraird Date: Tue, 27 Jan 2026 14:40:01 +0000 Subject: [PATCH 14/26] Ensure test_glass_to_healpix_alm uses complex arrays --- tests/core/test_fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 635b85b88..0cc755cb3 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -651,11 +651,11 @@ def test_healpix_to_glass_spectra(compare: type[Compare]) -> None: def test_glass_to_healpix_alm(compare: type[Compare], xp: ModuleType) -> None: - inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) + inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33], dtype=xp.complex128) out = glass.fields._glass_to_healpix_alm(inp) compare.assert_array_equal( out, - xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), + xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33], dtype=xp.complex128), ) From 27a06c42a985d16159a3d2e88ff3475ebd1e7d22 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 11:35:45 +0000 Subject: [PATCH 15/26] Fix _generate_grf by setting value of z --- glass/fields.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index ac7429467..2fa9dc97e 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -388,8 +388,8 @@ def _generate_grf( cov = cls2cov(gls, n, ngrf, ncorr) # working arrays for the iterative sampling - z = xp.zeros(n * (n + 1) // 2, dtype=xp.complex128) - y = xp.zeros((n * (n + 1) // 2, ncorr), dtype=xp.complex128) + z_size = n * (n + 1) // 2 + y = xp.zeros((z_size, ncorr), dtype=xp.complex128) # generate the conditional normal distribution for iterative sampling conditional_dist = iternorm(ncorr, cov, size=n) @@ -398,7 +398,7 @@ def _generate_grf( for j, a, s in conditional_dist: # standard normal random variates for alm # sample real and imaginary parts, then view as complex number - rng.standard_normal((n * (n + 1),)) + (1j * rng.standard_normal((n * (n + 1),))) + z = rng.standard_normal((z_size,)) + (1j * rng.standard_normal((z_size,))) # scale by standard deviation of the conditional distribution # variance is distributed over real and imaginary part From 93b3ac8c16db8e7bea1283712b05ec598facad98 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 11:54:22 +0000 Subject: [PATCH 16/26] Remove redundant conversion to array backend --- glass/fields.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 2fa9dc97e..5fc9861cc 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -832,7 +832,7 @@ def generate( nside: int, *, ncorr: int | None = None, - rng: np.random.Generator | None = None, + rng: UnifiedGenerator | None = None, ) -> Iterator[AnyArray]: """ Sample random fields from Gaussian angular power spectra. @@ -871,14 +871,12 @@ def generate( Sampled random fields. """ - xp = array_api_compat.array_namespace(*gls, use_compat=False) - n = len(fields) if len(gls) != n * (n + 1) // 2: msg = "mismatch between number of fields and gls" raise ValueError(msg) - variances = (xp.asarray(cltovar(np.asarray(getcl(gls, i, i)))) for i in range(n)) + variances = (cltovar(np.asarray(getcl(gls, i, i))) for i in range(n)) grf = _generate_grf(gls, nside, ncorr=ncorr, rng=rng) for t, x, var in zip(fields, grf, variances, strict=True): From d47a67c0a57ab8ad55f35294a28c57978fb95a22 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 12:07:29 +0000 Subject: [PATCH 17/26] Add tril_indices to uxpx --- glass/_array_api_utils.py | 40 ++++++++++++++++++++++++++++++++++++++- glass/fields.py | 2 +- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index e73236d67..3986fc95f 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -27,7 +27,7 @@ import numpy as np - from glass._types import AnyArray, DTypeLike + from glass._types import AnyArray, DTypeLike, IntArray class CompatibleBackendNotFoundError(Exception): @@ -573,3 +573,41 @@ def ndindex(shape: tuple[int, ...], *, xp: ModuleType) -> np.ndindex: msg = "the array backend in not supported" raise NotImplementedError(msg) + + @staticmethod + def tril_indices( + n: int, + *, + k: int = 0, + m: int | None = None, + xp: ModuleType, + ) -> tuple[IntArray, ...]: + """ + Return the indices for the lower-triangle of an (n, m) array. + + Parameters + ---------- + n + The row dimension of the arrays for which the returned indices will be + valid. + k + Diagonal offset. + m + The column dimension of the arrays for which the returned arrays will be + valid. By default m is taken equal to n. + + Returns + ------- + The row and column indices, respectively. The row indices are sorted in + non-decreasing order, and the correspdonding column indices are strictly + increasing for each row. + """ + if xp.__name__ in {"numpy", "jax.numpy"}: + return xp.tril_indices(n, k=k, m=m) # type: ignore[no-any-return] + + if xp.__name__ == "array_api_strict": + np = import_numpy(xp.__name__) + return np.tril_indices(n, k=k, m=m) # type: ignore[no-any-return] + + msg = "the array backend in not supported" + raise NotImplementedError(msg) diff --git a/glass/fields.py b/glass/fields.py index 5fc9861cc..e9d9264a5 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -593,7 +593,7 @@ def spectra_indices(n: int, *, xp: ModuleType | None = None) -> IntArray: """ xp = _utils.default_xp() if xp is None else xp - i, j = np.tril_indices(n) + i, j = uxpx.tril_indices(n, xp=xp) return xp.asarray([i, i - j]).T From a5a0467fb2929c17217a60abaf779c0489aa66eb Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 12:20:34 +0000 Subject: [PATCH 18/26] Add FloatArray as a pixwin return type --- glass/healpix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/glass/healpix.py b/glass/healpix.py index edb86a521..0d8132915 100644 --- a/glass/healpix.py +++ b/glass/healpix.py @@ -319,7 +319,7 @@ def pixwin( lmax: int | None = None, pol: bool = False, xp: ModuleType | None = None, -) -> tuple[FloatArray, ...]: +) -> FloatArray | tuple[FloatArray, ...]: """ Return the pixel window function for the given nside. From d79234dc8dc788ef11134757297f69b28939018a Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 12:23:30 +0000 Subject: [PATCH 19/26] Add clarifying comment about porting solve via conversion on inut and output --- glass/grf/_solver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/glass/grf/_solver.py b/glass/grf/_solver.py index a25123c44..28bd4a882 100644 --- a/glass/grf/_solver.py +++ b/glass/grf/_solver.py @@ -81,6 +81,9 @@ def solve( # noqa: PLR0912, PLR0913 """ xp = cl.__array_namespace__() + + # This function is difficult to port to the Array API so for now we work + # in NumPy and ultimately convert back at the end of it. cl = np.asarray(cl) if t2 is None: From cdeb34edc0c629ef4d7cde2ddb3c71fde1d5eb47 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 12:25:12 +0000 Subject: [PATCH 20/26] fix mypy error --- tests/core/test_healpix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_healpix.py b/tests/core/test_healpix.py index e3f99ebd2..d169fd95d 100644 --- a/tests/core/test_healpix.py +++ b/tests/core/test_healpix.py @@ -287,7 +287,7 @@ def test_pixwin( # Normalize to tuple old = old if isinstance(old, tuple) else (old,) - new = new if isinstance(new, tuple) else (new,) # type: ignore[redundant-expr] + new = new if isinstance(new, tuple) else (new,) assert len(old) == len(new) for i in range(len(old)): From cc9d2e0f154f9f1c7284af7ab1cb187b5d1e0942 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 14:23:08 +0000 Subject: [PATCH 21/26] Extract xp when needed for clarity --- glass/fields.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index e9d9264a5..1f55d5488 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -2,7 +2,6 @@ from __future__ import annotations -import itertools import math import warnings from collections.abc import Sequence @@ -124,10 +123,6 @@ def iternorm( If the covariance matrix is not positive definite. """ - # Convert to list here to allow determining the namespace - first = next(cov) # type: ignore[call-overload] - xp = first.__array_namespace__() - n = (size,) if isinstance(size, int) else size m = np.zeros((*n, k, k)) @@ -137,13 +132,15 @@ def iternorm( j = 0 if k > 0 else None # We must use cov_expanded here as cov has been consumed to determine the namespace - for i, x in enumerate(itertools.chain([first], cov)): - x = np.asarray(x) # noqa: PLW2901 - if x.shape != q: + for i, x in enumerate(cov): + x_np = np.asarray(x) + if x_np.shape != q: try: - x = np.broadcast_to(x, q) # noqa: PLW2901 + x_np = np.broadcast_to(x_np, q) except ValueError: - msg = f"covariance row {i}: shape {x.shape} cannot be broadcast to {q}" + msg = ( + f"covariance row {i}: shape {x_np.shape} cannot be broadcast to {q}" + ) raise TypeError(msg) from None # only need to update matrix A if there are correlations @@ -160,7 +157,7 @@ def iternorm( ) # compute new vector a - c = x[..., 1:, np.newaxis] + c = x_np[..., 1:, np.newaxis] a = np.matmul(m[..., :j], c[..., k - j :, :]) a += np.matmul(m[..., j:], c[..., : k - j, :]) a = np.reshape(a, (*n, k)) @@ -169,13 +166,15 @@ def iternorm( j = (j - 1) % k # compute new standard deviation - s = x[..., 0] - np.einsum("...i,...i", a, a) + s = x_np[..., 0] - np.einsum("...i,...i", a, a) if np.any(s < 0): msg = "covariance matrix is not positive definite" raise ValueError(msg) s = np.sqrt(s) # yield the next index, vector a, and standard deviation s + # converting back to the input array namespace. + xp = x.__array_namespace__() yield j, xp.asarray(a), xp.asarray(s) From b5a14d31c9bc1e7451cd4c6a0c00cdc2fbe12226 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 14:28:45 +0000 Subject: [PATCH 22/26] Return the correct array type from tril_indices --- glass/_array_api_utils.py | 4 ++-- glass/fields.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index 3986fc95f..2406aa9af 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -599,7 +599,7 @@ def tril_indices( Returns ------- The row and column indices, respectively. The row indices are sorted in - non-decreasing order, and the correspdonding column indices are strictly + non-decreasing order, and the corresponding column indices are strictly increasing for each row. """ if xp.__name__ in {"numpy", "jax.numpy"}: @@ -607,7 +607,7 @@ def tril_indices( if xp.__name__ == "array_api_strict": np = import_numpy(xp.__name__) - return np.tril_indices(n, k=k, m=m) # type: ignore[no-any-return] + return tuple(xp.asarray(arr) for arr in np.tril_indices(n, k=k, m=m)) msg = "the array backend in not supported" raise NotImplementedError(msg) diff --git a/glass/fields.py b/glass/fields.py index 1f55d5488..925d85076 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -593,7 +593,7 @@ def spectra_indices(n: int, *, xp: ModuleType | None = None) -> IntArray: xp = _utils.default_xp() if xp is None else xp i, j = uxpx.tril_indices(n, xp=xp) - return xp.asarray([i, i - j]).T + return xp.stack([i, i - j]).T def effective_cls( From 3dd2a45e1f79e7897235cc19e4d5568a9d0b1836 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 14:35:30 +0000 Subject: [PATCH 23/26] Add comment to explain why np.asarray is needed --- glass/fields.py | 1 + 1 file changed, 1 insertion(+) diff --git a/glass/fields.py b/glass/fields.py index 925d85076..722083ea6 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -875,6 +875,7 @@ def generate( msg = "mismatch between number of fields and gls" raise ValueError(msg) + # cltovar requires numpy but getcl maintains xp, so conversion is required variances = (cltovar(np.asarray(getcl(gls, i, i))) for i in range(n)) grf = _generate_grf(gls, nside, ncorr=ncorr, rng=rng) From d4fb652c7db6f04afef235f134c71b8c8f53da89 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 15:13:14 +0000 Subject: [PATCH 24/26] Add newline in docstring --- glass/_array_api_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index 2406aa9af..45d78c2cb 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -601,6 +601,7 @@ def tril_indices( The row and column indices, respectively. The row indices are sorted in non-decreasing order, and the corresponding column indices are strictly increasing for each row. + """ if xp.__name__ in {"numpy", "jax.numpy"}: return xp.tril_indices(n, k=k, m=m) # type: ignore[no-any-return] From eb6ca31aab19d5534c7fff682e8b5790a49d4431 Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 28 Jan 2026 15:18:56 +0000 Subject: [PATCH 25/26] Reorder comments --- glass/fields.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 722083ea6..6c60a3e3e 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -172,9 +172,10 @@ def iternorm( raise ValueError(msg) s = np.sqrt(s) - # yield the next index, vector a, and standard deviation s - # converting back to the input array namespace. + # Extract input array backend or conversion of outputs xp = x.__array_namespace__() + + # yield the next index, vector a, and standard deviation s yield j, xp.asarray(a), xp.asarray(s) From 690182caf946ec5c756d9f9c23f2046bd0373323 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Thu, 29 Jan 2026 12:07:31 +0000 Subject: [PATCH 26/26] gh-1000: Remove `conftest` imports and consistent `pytest.skip` (#1001) --- glass/fields.py | 2 +- tests/benchmarks/test_fields.py | 27 ++++++++++----------------- tests/benchmarks/test_galaxies.py | 4 ---- tests/benchmarks/test_harmonics.py | 3 ++- tests/benchmarks/test_lensing.py | 5 +---- tests/benchmarks/test_points.py | 4 ---- tests/benchmarks/test_shells.py | 3 ++- tests/core/test_fields.py | 2 +- tests/core/test_galaxies.py | 7 +++---- 9 files changed, 20 insertions(+), 37 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 5846c3e1d..7cb8fc6db 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -417,7 +417,7 @@ def _generate_grf( alm = _glass_to_healpix_alm(alm) # modes with m = 0 are real-valued and come first in array - alm = xpx.at(alm)[:n].set(xp.real(alm[:n])[:] + xp.imag(alm[:n]) + 0j) + alm = xpx.at(alm)[:n].set(xp.real(alm[:n]) + xp.imag(alm[:n]) + 0j) # transform alm to maps # can be performed in place on the temporary alm array diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index e1d4eecf6..0f8d9b3c3 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -13,10 +13,10 @@ from types import ModuleType from typing import Any - from conftest import Compare, GeneratorConsumer from pytest_benchmark.fixture import BenchmarkFixture from glass._types import AngularPowerSpectra, UnifiedGenerator + from tests.fixtures.helper_classes import Compare, GeneratorConsumer @pytest.mark.stable @@ -31,7 +31,7 @@ def test_iternorm_no_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -77,7 +77,7 @@ def test_iternorm_specify_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in), size) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -112,7 +112,7 @@ def test_iternorm_k_0( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) results = benchmark(function_to_benchmark) @@ -140,7 +140,7 @@ def function_to_benchmark() -> list[Any]: nf, nc, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) covs = benchmark(function_to_benchmark) cov = covs[0] @@ -156,18 +156,14 @@ def function_to_benchmark() -> list[Any]: @pytest.mark.stable @pytest.mark.parametrize("use_rng", [False, True]) @pytest.mark.parametrize("ncorr", [None, 1]) -def test_generate_grf( # noqa: PLR0913 - xpb: ModuleType, +def test_generate_grf( benchmark: BenchmarkFixture, generator_consumer: GeneratorConsumer, + ncorr: int | None, urngb: UnifiedGenerator, use_rng: bool, # noqa: FBT001 - ncorr: int | None, ) -> None: """Benchmarks for glass.fields._generate_grf with positional arguments only.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.fields._generate_grf not yet ported for {xpb.__name__}") - gls: AngularPowerSpectra = [urngb.random(1_000)] nside = 4 @@ -178,7 +174,7 @@ def function_to_benchmark() -> list[Any]: rng=urngb if use_rng else None, ncorr=ncorr, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) gaussian_fields = benchmark(function_to_benchmark) @@ -195,9 +191,6 @@ def test_generate( ncorr: int | None, ) -> None: """Benchmarks for glass.generate.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.generate not yet ported for {xpb.__name__}") - n = 100 fields = [lambda x, var: x for _ in range(n)] # noqa: ARG005 fields[1] = lambda x, var: x**2 # noqa: ARG005 @@ -212,8 +205,8 @@ def function_to_benchmark() -> list[Any]: nside=nside, ncorr=ncorr, ) - return generator_consumer.consume( # type: ignore[no-any-return] - generator, + return generator_consumer.consume( + generator, # type: ignore[arg-type] valid_exception="covariance matrix is not positive definite", ) diff --git a/tests/benchmarks/test_galaxies.py b/tests/benchmarks/test_galaxies.py index ef906b5c3..26058efcb 100644 --- a/tests/benchmarks/test_galaxies.py +++ b/tests/benchmarks/test_galaxies.py @@ -66,13 +66,9 @@ def test_redshifts_from_nz( def test_galaxy_shear( benchmark: BenchmarkFixture, urngb: UnifiedGenerator, - xpb: ModuleType, reduced_shear: bool, # noqa: FBT001 ) -> None: """Benchmark for galaxies.galaxy_shear.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.galaxy_shear not yet ported for {xpb.__name__}") - scale_factor = 100 size = (12 * scale_factor,) kappa = urngb.normal(size=size) diff --git a/tests/benchmarks/test_harmonics.py b/tests/benchmarks/test_harmonics.py index 5021ad995..c004c36c1 100644 --- a/tests/benchmarks/test_harmonics.py +++ b/tests/benchmarks/test_harmonics.py @@ -13,9 +13,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_multalm( diff --git a/tests/benchmarks/test_lensing.py b/tests/benchmarks/test_lensing.py index a60a53c3a..7fc56f3ba 100644 --- a/tests/benchmarks/test_lensing.py +++ b/tests/benchmarks/test_lensing.py @@ -9,12 +9,12 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture from typing_extensions import Never from glass._types import FloatArray, UnifiedGenerator from glass.cosmology import Cosmology + from tests.fixtures.helper_classes import Compare @pytest.mark.stable @@ -85,9 +85,6 @@ def test_multi_plane_weights( xpb: ModuleType, ) -> None: """Benchmarks for add_window and add_plane with a multi_plane_weights.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.multi_plane_weights not yet ported for {xpb.__name__}") - # Use this over the fixture to allow us to add many more windows shells = [ glass.RadialWindow( diff --git a/tests/benchmarks/test_points.py b/tests/benchmarks/test_points.py index 51199ae5b..396b5aa55 100644 --- a/tests/benchmarks/test_points.py +++ b/tests/benchmarks/test_points.py @@ -40,10 +40,6 @@ def test_positions_from_delta( # noqa: PLR0913 remove_monopole: bool, # noqa: FBT001 ) -> None: """Benchmarks for glass.positions_from_delta.""" - if xpb.__name__ == "array_api_strict": - pytest.skip( - f"glass.lensing.multi_plane_matrix not yet ported for {xpb.__name__}", - ) nside = 48 npix = 12 * nside * nside diff --git a/tests/benchmarks/test_shells.py b/tests/benchmarks/test_shells.py index 66334f902..552d2efd4 100644 --- a/tests/benchmarks/test_shells.py +++ b/tests/benchmarks/test_shells.py @@ -9,9 +9,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_radialwindow( diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 0cc755cb3..d1f9fafe5 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -380,7 +380,7 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: - gls = [xp.asarray([1.0, 0.5, 0.1])] + gls: AngularPowerSpectra = [xp.asarray([1.0, 0.5, 0.1])] nside = 4 ncorr = 1 diff --git a/tests/core/test_galaxies.py b/tests/core/test_galaxies.py index 4ab0504b9..341f7f65d 100644 --- a/tests/core/test_galaxies.py +++ b/tests/core/test_galaxies.py @@ -18,7 +18,8 @@ def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None: if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in redshifts are not immutable, so do not support jax") + pytest.skip(f"glass.redshifts not yet ported for {xp.__name__}") + # create a mock radial window function w = mocker.Mock() w.za = xp.linspace(0.0, 1.0, 20) @@ -37,9 +38,7 @@ def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None: def test_redshifts_from_nz(urng: UnifiedGenerator, xp: ModuleType) -> None: if xp.__name__ == "jax.numpy": - pytest.skip( - "Arrays in redshifts_from_nz are not immutable, so do not support jax", - ) + pytest.skip(f"glass.redshifts_from_nz not yet ported for {xp.__name__}") # test sampling