From dac16eae01fbaa13a5217a521fd053ec6c6d43f6 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 10:51:59 +0000 Subject: [PATCH 1/9] gh-100: port `_generate_grf` --- glass/fields.py | 20 +++++++++++--------- tests/benchmarks/test_fields.py | 8 ++------ tests/core/test_fields.py | 17 +++++++++-------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 7af8aecc0..7ed9116b7 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -363,9 +363,12 @@ def _generate_grf( ------ ValueError If all gls are empty. + """ + xp = array_api_compat.array_namespace(*gls, use_compat=False) + if rng is None: - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) # number of gls and number of fields ngls = len(gls) @@ -376,7 +379,7 @@ def _generate_grf( ncorr = ngrf - 1 # number of modes - n = max((len(gl) for gl in gls), default=0) # type: ignore[arg-type] + n = max((gl.shape[0] for gl in gls), default=0) if n == 0: msg = "all gls are empty" raise ValueError(msg) @@ -385,8 +388,7 @@ def _generate_grf( cov = cls2cov(gls, n, ngrf, ncorr) # working arrays for the iterative sampling - z = np.zeros(n * (n + 1) // 2, dtype=np.complex128) - y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128) + y = xp.zeros((n * (n + 1) // 2, ncorr), dtype=xp.complex128) # generate the conditional normal distribution for iterative sampling conditional_dist = iternorm(ncorr, cov, size=n) @@ -394,8 +396,9 @@ def _generate_grf( # sample the fields from the conditional distribution for j, a, s in conditional_dist: # standard normal random variates for alm - # sample real and imaginary parts, then view as complex number - rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) # type: ignore[call-arg] + z = rng.standard_normal(n * (n + 1) // 2) + 1j * rng.standard_normal( + n * (n + 1) // 2, + ) # scale by standard deviation of the conditional distribution # variance is distributed over real and imaginary part @@ -407,13 +410,12 @@ def _generate_grf( # store the standard normal in y array at the indicated index if j is not None: - y[:, j] = z + y = xpx.at(y)[:, j].set(z) alm = _glass_to_healpix_alm(alm) # modes with m = 0 are real-valued and come first in array - np.real(alm[:n])[:] += np.imag(alm[:n]) - np.imag(alm[:n])[:] = 0 + alm = xpx.at(alm)[:n].set(xp.real(alm[:n]) + xp.imag(alm[:n]) + 0j) # transform alm to maps # can be performed in place on the temporary alm array diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index ef47f1a03..8b70a3598 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -16,7 +16,7 @@ from conftest import Compare, GeneratorConsumer from pytest_benchmark.fixture import BenchmarkFixture - from glass._types import UnifiedGenerator + from glass._types import AngularPowerSpectra, UnifiedGenerator @pytest.mark.stable @@ -157,7 +157,6 @@ def function_to_benchmark() -> list[Any]: @pytest.mark.parametrize("use_rng", [False, True]) @pytest.mark.parametrize("ncorr", [None, 1]) def test_generate_grf( # noqa: PLR0913 - xpb: ModuleType, benchmark: BenchmarkFixture, generator_consumer: GeneratorConsumer, urngb: UnifiedGenerator, @@ -165,10 +164,7 @@ def test_generate_grf( # noqa: PLR0913 ncorr: int | None, ) -> None: """Benchmarks for glass.fields._generate_grf with positional arguments only.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.fields._generate_grf not yet ported for {xpb.__name__}") - - gls = [urngb.random(1_000)] + gls: AngularPowerSpectra = [urngb.random(1_000)] nside = 4 def function_to_benchmark() -> list[Any]: diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index cfb1e5fda..e7e877578 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -16,6 +16,7 @@ from pytest_mock import MockerFixture + from glass._types import AngularPowerSpectra from tests.fixtures.helper_classes import Compare HAVE_JAX = importlib.util.find_spec("jax") is not None @@ -383,8 +384,8 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: assert result.shape == (1, 1, 15) -def test_generate_grf(compare: type[Compare]) -> None: - gls = [np.asarray([1.0, 0.5, 0.1])] +def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: + gls: AngularPowerSpectra = [xp.asarray([1.0, 0.5, 0.1])] nside = 4 ncorr = 1 @@ -393,13 +394,13 @@ def test_generate_grf(compare: type[Compare]) -> None: assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) gaussian_fields = list(glass.fields._generate_grf(gls, nside, rng=rng)) assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) new_gaussian_fields = list( glass.fields._generate_grf(gls, nside, ncorr=ncorr, rng=rng), ) @@ -409,7 +410,7 @@ def test_generate_grf(compare: type[Compare]) -> None: compare.assert_allclose(new_gaussian_fields[0], gaussian_fields[0]) with pytest.raises(ValueError, match="all gls are empty"): - list(glass.fields._generate_grf([], nside)) + list(glass.fields._generate_grf(xp.asarray([]), nside)) def test_generate_gaussian(xp: ModuleType) -> None: @@ -651,12 +652,12 @@ def test_healpix_to_glass_spectra(compare: type[Compare]) -> None: compare.assert_array_equal(out, [11, 22, 21, 33, 32, 31, 44, 43, 42, 41]) -def test_glass_to_healpix_alm(compare: type[Compare]) -> None: - inp = np.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) +def test_glass_to_healpix_alm(compare: type[Compare], xp: ModuleType) -> None: + inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) out = glass.fields._glass_to_healpix_alm(inp) compare.assert_array_equal( out, - np.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), + xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), ) From 3fdb390d9a409da8c8c5ff1c4b364cb0ba4e965d Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 10:54:22 +0000 Subject: [PATCH 2/9] Revert test change --- tests/core/test_fields.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index e7e877578..a9545a632 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -652,12 +652,12 @@ def test_healpix_to_glass_spectra(compare: type[Compare]) -> None: compare.assert_array_equal(out, [11, 22, 21, 33, 32, 31, 44, 43, 42, 41]) -def test_glass_to_healpix_alm(compare: type[Compare], xp: ModuleType) -> None: - inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) +def test_glass_to_healpix_alm(compare: type[Compare]) -> None: + inp = np.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) out = glass.fields._glass_to_healpix_alm(inp) compare.assert_array_equal( out, - xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), + np.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), ) From 469b454ad32055e1d9672f8f1181a9bcc54581af Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 10:57:44 +0000 Subject: [PATCH 3/9] Remove noqa --- tests/benchmarks/test_fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index 8b70a3598..0f8e714d0 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -156,7 +156,7 @@ def function_to_benchmark() -> list[Any]: @pytest.mark.stable @pytest.mark.parametrize("use_rng", [False, True]) @pytest.mark.parametrize("ncorr", [None, 1]) -def test_generate_grf( # noqa: PLR0913 +def test_generate_grf( benchmark: BenchmarkFixture, generator_consumer: GeneratorConsumer, urngb: UnifiedGenerator, From 754d7293a6714455cf1ccea1f158b5897a092478 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 10:59:12 +0000 Subject: [PATCH 4/9] Put in list --- tests/core/test_fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index a9545a632..ed9f00106 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -410,7 +410,7 @@ def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: compare.assert_allclose(new_gaussian_fields[0], gaussian_fields[0]) with pytest.raises(ValueError, match="all gls are empty"): - list(glass.fields._generate_grf(xp.asarray([]), nside)) + list(glass.fields._generate_grf([xp.asarray([])], nside)) def test_generate_gaussian(xp: ModuleType) -> None: From 4d33f4e284e424872c2de905d0ba10af9b96fa6e Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 15:04:34 +0000 Subject: [PATCH 5/9] Pass array-api-strict --- glass/fields.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/glass/fields.py b/glass/fields.py index 7ed9116b7..5d020fc9f 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -940,10 +940,12 @@ def _glass_to_healpix_alm(alm: ComplexArray) -> ComplexArray: alm in HEALPix order. """ + xp = alm.__array_namespace__() + n = _inv_triangle_number(alm.size) - ell = np.arange(n) + ell = xp.arange(n) out = [alm[ell[m:] * (ell[m:] + 1) // 2 + m] for m in ell] - return np.concatenate(out) + return xp.concat(out) def lognormal_shift_hilbert2011(z: float) -> float: From 89af895ff70165e7a55084a428afbc09bca09332 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 15:17:49 +0000 Subject: [PATCH 6/9] Tidy up imports --- tests/benchmarks/test_fields.py | 2 +- tests/benchmarks/test_harmonics.py | 3 ++- tests/benchmarks/test_lensing.py | 2 +- tests/benchmarks/test_shells.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index 0f8e714d0..810e2cc7b 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -13,10 +13,10 @@ from types import ModuleType from typing import Any - from conftest import Compare, GeneratorConsumer from pytest_benchmark.fixture import BenchmarkFixture from glass._types import AngularPowerSpectra, UnifiedGenerator + from tests.fixtures.helper_classes import Compare, GeneratorConsumer @pytest.mark.stable diff --git a/tests/benchmarks/test_harmonics.py b/tests/benchmarks/test_harmonics.py index 5021ad995..c004c36c1 100644 --- a/tests/benchmarks/test_harmonics.py +++ b/tests/benchmarks/test_harmonics.py @@ -13,9 +13,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_multalm( diff --git a/tests/benchmarks/test_lensing.py b/tests/benchmarks/test_lensing.py index d34bac0bd..df6ba1784 100644 --- a/tests/benchmarks/test_lensing.py +++ b/tests/benchmarks/test_lensing.py @@ -9,12 +9,12 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture from typing_extensions import Never from glass._types import FloatArray, UnifiedGenerator from glass.cosmology import Cosmology + from tests.fixtures.helper_classes import Compare @pytest.mark.stable diff --git a/tests/benchmarks/test_shells.py b/tests/benchmarks/test_shells.py index 66334f902..552d2efd4 100644 --- a/tests/benchmarks/test_shells.py +++ b/tests/benchmarks/test_shells.py @@ -9,9 +9,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_radialwindow( From 0dd209a8d380874f0521266c23dc1afc8641335f Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 15:18:13 +0000 Subject: [PATCH 7/9] Don't skip in benchmarks --- tests/benchmarks/test_fields.py | 5 +---- tests/benchmarks/test_galaxies.py | 4 ---- tests/benchmarks/test_lensing.py | 6 ------ tests/benchmarks/test_points.py | 4 ---- 4 files changed, 1 insertion(+), 18 deletions(-) diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index 810e2cc7b..25b4dffa2 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -159,9 +159,9 @@ def function_to_benchmark() -> list[Any]: def test_generate_grf( benchmark: BenchmarkFixture, generator_consumer: GeneratorConsumer, + ncorr: int | None, urngb: UnifiedGenerator, use_rng: bool, # noqa: FBT001 - ncorr: int | None, ) -> None: """Benchmarks for glass.fields._generate_grf with positional arguments only.""" gls: AngularPowerSpectra = [urngb.random(1_000)] @@ -191,9 +191,6 @@ def test_generate( ncorr: int | None, ) -> None: """Benchmarks for glass.generate.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.generate not yet ported for {xpb.__name__}") - n = 100 fields = [lambda x, var: x for _ in range(n)] # noqa: ARG005 fields[1] = lambda x, var: x**2 # noqa: ARG005 diff --git a/tests/benchmarks/test_galaxies.py b/tests/benchmarks/test_galaxies.py index ef906b5c3..26058efcb 100644 --- a/tests/benchmarks/test_galaxies.py +++ b/tests/benchmarks/test_galaxies.py @@ -66,13 +66,9 @@ def test_redshifts_from_nz( def test_galaxy_shear( benchmark: BenchmarkFixture, urngb: UnifiedGenerator, - xpb: ModuleType, reduced_shear: bool, # noqa: FBT001 ) -> None: """Benchmark for galaxies.galaxy_shear.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.galaxy_shear not yet ported for {xpb.__name__}") - scale_factor = 100 size = (12 * scale_factor,) kappa = urngb.normal(size=size) diff --git a/tests/benchmarks/test_lensing.py b/tests/benchmarks/test_lensing.py index df6ba1784..fb5e017ec 100644 --- a/tests/benchmarks/test_lensing.py +++ b/tests/benchmarks/test_lensing.py @@ -26,9 +26,6 @@ def test_multi_plane_matrix( xpb: ModuleType, ) -> None: """Benchmarks for add_window and add_plane with a multi_plane_matrix.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.multi_plane_matrix not yet ported for {xpb.__name__}") - # Use this over the fixture to allow us to add many more windows shells = [ glass.RadialWindow( @@ -84,9 +81,6 @@ def test_multi_plane_weights( xpb: ModuleType, ) -> None: """Benchmarks for add_window and add_plane with a multi_plane_weights.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.multi_plane_weights not yet ported for {xpb.__name__}") - # Use this over the fixture to allow us to add many more windows shells = [ glass.RadialWindow( diff --git a/tests/benchmarks/test_points.py b/tests/benchmarks/test_points.py index 5451b51f4..727641b70 100644 --- a/tests/benchmarks/test_points.py +++ b/tests/benchmarks/test_points.py @@ -41,10 +41,6 @@ def test_positions_from_delta( # noqa: PLR0913 remove_monopole: bool, # noqa: FBT001 ) -> None: """Benchmarks for glass.positions_from_delta.""" - if xpb.__name__ == "array_api_strict": - pytest.skip( - f"glass.lensing.multi_plane_matrix not yet ported for {xpb.__name__}", - ) nside = 48 npix = 12 * nside * nside From 9c666baa02fbc9f25d6897b0ba9da9d2bc8027ef Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 15:18:32 +0000 Subject: [PATCH 8/9] Clean up `pytest.skip` --- tests/core/test_fields.py | 9 +++++---- tests/core/test_galaxies.py | 7 +++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index ed9f00106..5fe7e8f20 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -28,9 +28,8 @@ def not_triangle_numbers() -> list[int]: def test_iternorm(xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in iternorm are not immutable, so do not support jax") + pytest.skip(f"glass.iternorm not yet ported for {xp.__name__}") # check output shapes and types @@ -344,9 +343,8 @@ def test_discretized_cls(compare: type[Compare]) -> None: def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in effective_cls are not immutable, so do not support jax") + pytest.skip(f"glass.effective_cls not yet ported for {xp.__name__}") # empty cls @@ -385,6 +383,9 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: + if xp.__name__ == "jax.numpy": + pytest.skip(f"glass.fields._generate_grf not yet ported for {xp.__name__}") + gls: AngularPowerSpectra = [xp.asarray([1.0, 0.5, 0.1])] nside = 4 ncorr = 1 diff --git a/tests/core/test_galaxies.py b/tests/core/test_galaxies.py index 4ab0504b9..341f7f65d 100644 --- a/tests/core/test_galaxies.py +++ b/tests/core/test_galaxies.py @@ -18,7 +18,8 @@ def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None: if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in redshifts are not immutable, so do not support jax") + pytest.skip(f"glass.redshifts not yet ported for {xp.__name__}") + # create a mock radial window function w = mocker.Mock() w.za = xp.linspace(0.0, 1.0, 20) @@ -37,9 +38,7 @@ def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None: def test_redshifts_from_nz(urng: UnifiedGenerator, xp: ModuleType) -> None: if xp.__name__ == "jax.numpy": - pytest.skip( - "Arrays in redshifts_from_nz are not immutable, so do not support jax", - ) + pytest.skip(f"glass.redshifts_from_nz not yet ported for {xp.__name__}") # test sampling From d3297288812a852088d4d6b62747a9c23c2fd814 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Mon, 26 Jan 2026 15:22:32 +0000 Subject: [PATCH 9/9] Fix mypy --- tests/benchmarks/test_fields.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index 25b4dffa2..21f0d7e57 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -31,7 +31,7 @@ def test_iternorm_no_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -77,7 +77,7 @@ def test_iternorm_specify_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in), size) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -112,7 +112,7 @@ def test_iternorm_k_0( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) results = benchmark(function_to_benchmark) @@ -140,7 +140,7 @@ def function_to_benchmark() -> list[Any]: nf, nc, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) covs = benchmark(function_to_benchmark) cov = covs[0] @@ -174,7 +174,7 @@ def function_to_benchmark() -> list[Any]: rng=urngb if use_rng else None, ncorr=ncorr, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) gaussian_fields = benchmark(function_to_benchmark) @@ -205,8 +205,8 @@ def function_to_benchmark() -> list[Any]: nside=nside, ncorr=ncorr, ) - return generator_consumer.consume( # type: ignore[no-any-return] - generator, + return generator_consumer.consume( + generator, # type: ignore[arg-type] valid_exception="covariance matrix is not positive definite", )