diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index ac771783a..ca13d8ce3 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -27,7 +27,7 @@ import numpy as np - from glass._types import AnyArray, DTypeLike + from glass._types import AnyArray, DTypeLike, IntArray class CompatibleBackendNotFoundError(Exception): @@ -576,3 +576,42 @@ def ndindex(shape: tuple[int, ...], *, xp: ModuleType) -> np.ndindex: msg = "the array backend in not supported" raise NotImplementedError(msg) + + @staticmethod + def tril_indices( + n: int, + *, + k: int = 0, + m: int | None = None, + xp: ModuleType, + ) -> tuple[IntArray, ...]: + """ + Return the indices for the lower-triangle of an (n, m) array. + + Parameters + ---------- + n + The row dimension of the arrays for which the returned indices will be + valid. + k + Diagonal offset. + m + The column dimension of the arrays for which the returned arrays will be + valid. By default m is taken equal to n. + + Returns + ------- + The row and column indices, respectively. The row indices are sorted in + non-decreasing order, and the corresponding column indices are strictly + increasing for each row. + + """ + if xp.__name__ in {"numpy", "jax.numpy"}: + return xp.tril_indices(n, k=k, m=m) # type: ignore[no-any-return] + + if xp.__name__ == "array_api_strict": + np = import_numpy(xp.__name__) + return tuple(xp.asarray(arr) for arr in np.tril_indices(n, k=k, m=m)) + + msg = "the array backend in not supported" + raise NotImplementedError(msg) diff --git a/glass/fields.py b/glass/fields.py index 43fa3e787..7cb8fc6db 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -2,7 +2,6 @@ from __future__ import annotations -import itertools import math import warnings from collections.abc import Sequence @@ -126,66 +125,60 @@ def iternorm( If the covariance matrix is not positive definite. """ - # Convert to list here to allow determining the namespace - first = next(cov) # type: ignore[call-overload] - xp = first.__array_namespace__() - n = (size,) if isinstance(size, int) else size - m = xp.zeros((*n, k, k)) - a = xp.zeros((*n, k)) - s = xp.zeros((*n,)) + m = np.zeros((*n, k, k)) + a = np.zeros((*n, k)) + s = np.zeros((*n,)) q = (*n, k + 1) j = 0 if k > 0 else None # We must use cov_expanded here as cov has been consumed to determine the namespace - for i, x in enumerate(itertools.chain([first], cov)): - # Ideally would be xp.asanyarray but this does not yet exist. The key difference - # between the two in numpy is that asanyarray maintains subclasses of NDArray - # whereas asarray will return the base class NDArray. Currently, we don't seem - # to pass a subclass of NDArray so this, so it might be okay - x = xp.asarray(x) # noqa: PLW2901 - if x.shape != q: + for i, x in enumerate(cov): + x_np = np.asarray(x) + if x_np.shape != q: try: - x = xp.broadcast_to(x, q) # noqa: PLW2901 + x_np = np.broadcast_to(x_np, q) except ValueError: - msg = f"covariance row {i}: shape {x.shape} cannot be broadcast to {q}" + msg = ( + f"covariance row {i}: shape {x_np.shape} cannot be broadcast to {q}" + ) raise TypeError(msg) from None # only need to update matrix A if there are correlations if j is not None: # compute new entries of matrix A m[..., :, j] = 0 - m[..., j : j + 1, :] = xp.matmul(a[..., xp.newaxis, :], m) - m[..., j, j] = xp.where(s != 0, -1, s) - # To ensure we don't divide by zero or nan we use a mask to only divide the - # appropriate values of m and s - m_j = m[..., j, :] - s_broadcast = xp.broadcast_to(s[..., xp.newaxis], m_j.shape) - mask = (m_j != 0) & (s_broadcast != 0) & ~xp.isnan(s_broadcast) - m_j[mask] = xp.divide(m_j[mask], -s_broadcast[mask]) - m[..., j, :] = m_j + m[..., j : j + 1, :] = np.matmul(a[..., np.newaxis, :], m) + m[..., j, j] = np.where(s != 0, -1, 0) + np.divide( + m[..., j, :], + -s[..., np.newaxis], + where=(m[..., j, :] != 0), + out=m[..., j, :], + ) # compute new vector a - c = x[..., 1:, xp.newaxis] - a = xp.matmul(m[..., :j], c[..., k - j :, :]) - a += xp.matmul(m[..., j:], c[..., : k - j, :]) - a = xp.reshape(a, (*n, k)) + c = x_np[..., 1:, np.newaxis] + a = np.matmul(m[..., :j], c[..., k - j :, :]) + a += np.matmul(m[..., j:], c[..., : k - j, :]) + a = np.reshape(a, (*n, k)) # next rolling index j = (j - 1) % k # compute new standard deviation - a_np = np.asarray(a, copy=True) - einsum_result_np = np.einsum("...i,...i", a_np, a_np) - s = x[..., 0] - xp.asarray(einsum_result_np, copy=True) - if xp.any(s < 0): + s = x_np[..., 0] - np.einsum("...i,...i", a, a) + if np.any(s < 0): msg = "covariance matrix is not positive definite" raise ValueError(msg) - s = xp.sqrt(s) + s = np.sqrt(s) + + # Extract input array backend or conversion of outputs + xp = x.__array_namespace__() # yield the next index, vector a, and standard deviation s - yield j, a, s + yield j, xp.asarray(a), xp.asarray(s) def cls2cov( @@ -272,25 +265,30 @@ def discretized_cls( If the length of the Cls array is not a triangle number. """ + if len(cls) == 0: + return [] + + xp = array_api_compat.array_namespace(*cls, use_compat=False) + if ncorr is not None: n = nfields_from_nspectra(len(cls)) cls = [ - cls[i * (i + 1) // 2 + j] if j <= ncorr else np.asarray([]) + cls[i * (i + 1) // 2 + j] if j <= ncorr else xp.asarray([]) for i in range(n) for j in range(i + 1) ] if nside is not None: - pw = hp.pixwin(nside, lmax=lmax, xp=np) + pw: FloatArray = hp.pixwin(nside, lmax=lmax, xp=xp) gls = [] for cl in cls: - if len(cl) > 0: # type: ignore[arg-type] + if cl.shape[0] > 0: if lmax is not None: cl = cl[: lmax + 1] # noqa: PLW2901 if nside is not None: - n = min(len(cl), len(pw)) # type: ignore[arg-type] - cl = cl[:n] * pw[:n] ** 2 # type: ignore[operator] # noqa: PLW2901 + n = min(cl.shape[0], pw.shape[0]) + cl = cl[:n] * pw[:n] ** 2 # noqa: PLW2901 gls.append(cl) return gls @@ -369,8 +367,10 @@ def _generate_grf( If all gls are empty. """ + xp = array_api_compat.array_namespace(*gls, use_compat=False) + if rng is None: - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) # number of gls and number of fields ngls = len(gls) @@ -381,7 +381,7 @@ def _generate_grf( ncorr = ngrf - 1 # number of modes - n = max((len(gl) for gl in gls), default=0) # type: ignore[arg-type] + n = max((gl.shape[0] for gl in gls), default=0) if n == 0: msg = "all gls are empty" raise ValueError(msg) @@ -390,8 +390,8 @@ def _generate_grf( cov = cls2cov(gls, n, ngrf, ncorr) # working arrays for the iterative sampling - z = np.zeros(n * (n + 1) // 2, dtype=np.complex128) - y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128) + z_size = n * (n + 1) // 2 + y = xp.zeros((z_size, ncorr), dtype=xp.complex128) # generate the conditional normal distribution for iterative sampling conditional_dist = iternorm(ncorr, cov, size=n) @@ -400,7 +400,7 @@ def _generate_grf( for j, a, s in conditional_dist: # standard normal random variates for alm # sample real and imaginary parts, then view as complex number - rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) # type: ignore[call-arg] + z = rng.standard_normal((z_size,)) + (1j * rng.standard_normal((z_size,))) # scale by standard deviation of the conditional distribution # variance is distributed over real and imaginary part @@ -412,13 +412,12 @@ def _generate_grf( # store the standard normal in y array at the indicated index if j is not None: - y[:, j] = z + y = xpx.at(y)[:, j].set(z) alm = _glass_to_healpix_alm(alm) # modes with m = 0 are real-valued and come first in array - np.real(alm[:n])[:] += np.imag(alm[:n]) - np.imag(alm[:n])[:] = 0 + alm = xpx.at(alm)[:n].set(xp.real(alm[:n]) + xp.imag(alm[:n]) + 0j) # transform alm to maps # can be performed in place on the temporary alm array @@ -596,8 +595,8 @@ def spectra_indices(n: int, *, xp: ModuleType | None = None) -> IntArray: """ xp = _utils.default_xp() if xp is None else xp - i, j = xp.tril_indices(n) - return xp.asarray([i, i - j]).T + i, j = uxpx.tril_indices(n, xp=xp) + return xp.stack([i, i - j]).T def effective_cls( @@ -681,9 +680,9 @@ def effective_cls( for i1 in range(n) for i2 in range(n) ) - out[j1 + j2 + (...,)] = cl + out = xpx.at(out)[j1 + j2 + (...,)].set(cl) if weights2 is weights1 and j1 != j2: - out[j2 + j1 + (...,)] = cl + out = xpx.at(out)[j2 + j1 + (...,)].set(cl) return out @@ -796,6 +795,9 @@ def solve_gaussian_spectra( msg = "mismatch between number of fields and spectra" raise ValueError(msg) + if len(spectra) == 0: + return [] + gls = [] for i, j, cl in enumerate_spectra(spectra): if cl.size > 0: @@ -832,7 +834,7 @@ def generate( nside: int, *, ncorr: int | None = None, - rng: np.random.Generator | None = None, + rng: UnifiedGenerator | None = None, ) -> Iterator[AnyArray]: """ Sample random fields from Gaussian angular power spectra. @@ -876,7 +878,8 @@ def generate( msg = "mismatch between number of fields and gls" raise ValueError(msg) - variances = (cltovar(getcl(gls, i, i)) for i in range(n)) + # cltovar requires numpy but getcl maintains xp, so conversion is required + variances = (cltovar(np.asarray(getcl(gls, i, i))) for i in range(n)) grf = _generate_grf(gls, nside, ncorr=ncorr, rng=rng) for t, x, var in zip(fields, grf, variances, strict=True): @@ -995,6 +998,8 @@ def cov_from_spectra( Covariance matrix from the given spectra. """ + xp = array_api_compat.array_namespace(*spectra, use_compat=False) + # recover the number of fields from the number of spectra n = nfields_from_nspectra(len(spectra)) @@ -1004,14 +1009,17 @@ def cov_from_spectra( # this is the covariance matrix of the spectra # the leading dimension is k, then it is a n-by-n covariance matrix # missing entries are zero, which is the default value - cov = np.zeros((k, n, n)) + cov = xp.zeros((k, n, n), dtype=spectra[0].dtype) # fill the matrix up by going through the spectra in order # skip over entries that are None # if the spectra are ragged, some entries at high ell may remain zero # only fill the lower triangular part, everything is symmetric for i, j, cl in enumerate_spectra(spectra): - cov[: cl.size, i, j] = cov[: cl.size, j, i] = cl.reshape(-1)[:k] # type: ignore[union-attr] + size = min(k, cl.size) + cl_flat = xp.reshape(cl, (-1,)) + cov = xpx.at(cov)[:size, i, j].set(cl_flat[:size]) + cov = xpx.at(cov)[:size, j, i].set(cl_flat[:size]) return cov diff --git a/glass/grf/_solver.py b/glass/grf/_solver.py index 8533b6d43..28bd4a882 100644 --- a/glass/grf/_solver.py +++ b/glass/grf/_solver.py @@ -80,10 +80,16 @@ def solve( # noqa: PLR0912, PLR0913 :func:`glass.grf.compute`: Direct computation for band-limited spectra. """ + xp = cl.__array_namespace__() + + # This function is difficult to port to the Array API so for now we work + # in NumPy and ultimately convert back at the end of it. + cl = np.asarray(cl) + if t2 is None: t2 = t1 - n = len(cl) # type: ignore[arg-type] + n = len(cl) if pad < 0: msg = "pad must be a positive integer" raise ValueError(msg) @@ -138,4 +144,4 @@ def solve( # noqa: PLR0912, PLR0913 gl, gt, rl, fl, clerr = gl_, gt_, rl_, fl_, clerr_ - return gl, rl, info + return xp.asarray(gl), xp.asarray(rl), info diff --git a/glass/healpix.py b/glass/healpix.py index 0096c3b7b..91a5a3eca 100644 --- a/glass/healpix.py +++ b/glass/healpix.py @@ -319,7 +319,7 @@ def pixwin( lmax: int | None = None, pol: bool = False, xp: ModuleType | None = None, -) -> tuple[FloatArray, ...]: +) -> FloatArray | tuple[FloatArray, ...]: """ Return the pixel window function for the given nside. diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index e1d4eecf6..0f8d9b3c3 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -13,10 +13,10 @@ from types import ModuleType from typing import Any - from conftest import Compare, GeneratorConsumer from pytest_benchmark.fixture import BenchmarkFixture from glass._types import AngularPowerSpectra, UnifiedGenerator + from tests.fixtures.helper_classes import Compare, GeneratorConsumer @pytest.mark.stable @@ -31,7 +31,7 @@ def test_iternorm_no_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -77,7 +77,7 @@ def test_iternorm_specify_size( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in), size) - return generator_consumer.consume( # type: ignore[no-any-return] + return generator_consumer.consume( generator, valid_exception="covariance matrix is not positive definite", ) @@ -112,7 +112,7 @@ def test_iternorm_k_0( def function_to_benchmark() -> list[Any]: generator = glass.iternorm(k, iter(array_in)) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) results = benchmark(function_to_benchmark) @@ -140,7 +140,7 @@ def function_to_benchmark() -> list[Any]: nf, nc, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) covs = benchmark(function_to_benchmark) cov = covs[0] @@ -156,18 +156,14 @@ def function_to_benchmark() -> list[Any]: @pytest.mark.stable @pytest.mark.parametrize("use_rng", [False, True]) @pytest.mark.parametrize("ncorr", [None, 1]) -def test_generate_grf( # noqa: PLR0913 - xpb: ModuleType, +def test_generate_grf( benchmark: BenchmarkFixture, generator_consumer: GeneratorConsumer, + ncorr: int | None, urngb: UnifiedGenerator, use_rng: bool, # noqa: FBT001 - ncorr: int | None, ) -> None: """Benchmarks for glass.fields._generate_grf with positional arguments only.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.fields._generate_grf not yet ported for {xpb.__name__}") - gls: AngularPowerSpectra = [urngb.random(1_000)] nside = 4 @@ -178,7 +174,7 @@ def function_to_benchmark() -> list[Any]: rng=urngb if use_rng else None, ncorr=ncorr, ) - return generator_consumer.consume(generator) # type: ignore[no-any-return] + return generator_consumer.consume(generator) gaussian_fields = benchmark(function_to_benchmark) @@ -195,9 +191,6 @@ def test_generate( ncorr: int | None, ) -> None: """Benchmarks for glass.generate.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.generate not yet ported for {xpb.__name__}") - n = 100 fields = [lambda x, var: x for _ in range(n)] # noqa: ARG005 fields[1] = lambda x, var: x**2 # noqa: ARG005 @@ -212,8 +205,8 @@ def function_to_benchmark() -> list[Any]: nside=nside, ncorr=ncorr, ) - return generator_consumer.consume( # type: ignore[no-any-return] - generator, + return generator_consumer.consume( + generator, # type: ignore[arg-type] valid_exception="covariance matrix is not positive definite", ) diff --git a/tests/benchmarks/test_harmonics.py b/tests/benchmarks/test_harmonics.py index 5021ad995..c004c36c1 100644 --- a/tests/benchmarks/test_harmonics.py +++ b/tests/benchmarks/test_harmonics.py @@ -13,9 +13,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_multalm( diff --git a/tests/benchmarks/test_lensing.py b/tests/benchmarks/test_lensing.py index a60a53c3a..7fc56f3ba 100644 --- a/tests/benchmarks/test_lensing.py +++ b/tests/benchmarks/test_lensing.py @@ -9,12 +9,12 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture from typing_extensions import Never from glass._types import FloatArray, UnifiedGenerator from glass.cosmology import Cosmology + from tests.fixtures.helper_classes import Compare @pytest.mark.stable @@ -85,9 +85,6 @@ def test_multi_plane_weights( xpb: ModuleType, ) -> None: """Benchmarks for add_window and add_plane with a multi_plane_weights.""" - if xpb.__name__ == "array_api_strict": - pytest.skip(f"glass.multi_plane_weights not yet ported for {xpb.__name__}") - # Use this over the fixture to allow us to add many more windows shells = [ glass.RadialWindow( diff --git a/tests/benchmarks/test_points.py b/tests/benchmarks/test_points.py index 51199ae5b..396b5aa55 100644 --- a/tests/benchmarks/test_points.py +++ b/tests/benchmarks/test_points.py @@ -40,10 +40,6 @@ def test_positions_from_delta( # noqa: PLR0913 remove_monopole: bool, # noqa: FBT001 ) -> None: """Benchmarks for glass.positions_from_delta.""" - if xpb.__name__ == "array_api_strict": - pytest.skip( - f"glass.lensing.multi_plane_matrix not yet ported for {xpb.__name__}", - ) nside = 48 npix = 12 * nside * nside diff --git a/tests/benchmarks/test_shells.py b/tests/benchmarks/test_shells.py index 66334f902..552d2efd4 100644 --- a/tests/benchmarks/test_shells.py +++ b/tests/benchmarks/test_shells.py @@ -9,9 +9,10 @@ if TYPE_CHECKING: from types import ModuleType - from conftest import Compare from pytest_benchmark.fixture import BenchmarkFixture + from tests.fixtures.helper_classes import Compare + @pytest.mark.unstable def test_radialwindow( diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index 55eb9dac2..d1f9fafe5 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -3,7 +3,6 @@ import importlib.util from typing import TYPE_CHECKING -import numpy as np import pytest import glass @@ -16,7 +15,7 @@ from pytest_mock import MockerFixture - from glass._types import AngularPowerSpectra + from glass._types import AngularPowerSpectra, FloatArray, UnifiedGenerator from tests.fixtures.helper_classes import Compare HAVE_JAX = importlib.util.find_spec("jax") is not None @@ -28,10 +27,6 @@ def not_triangle_numbers() -> list[int]: def test_iternorm(xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written - if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in iternorm are not immutable, so do not support jax") - # check output shapes and types k = 2 @@ -277,7 +272,7 @@ def test_cls2cov_no_jax(compare: type[Compare], xpb: ModuleType) -> None: compare.assert_allclose(cov2_copy, cov3) -def test_lognormal_gls() -> None: +def test_lognormal_gls(xp: ModuleType) -> None: shift = 2 # empty cls @@ -286,19 +281,20 @@ def test_lognormal_gls() -> None: # check output shape - assert len(glass.lognormal_gls([np.linspace(1, 5, 5)], shift)) == 1 - assert len(glass.lognormal_gls([np.linspace(1, 5, 5)], shift)[0]) == 5 + out = glass.lognormal_gls([xp.linspace(1, 5, 5)], shift) + assert len(out) == 1 + assert out[0].shape[0] == 5 - inp = [np.linspace(1, 6, 5), np.linspace(1, 5, 4), np.linspace(1, 4, 3)] + inp = [xp.linspace(1, 6, 5), xp.linspace(1, 5, 4), xp.linspace(1, 4, 3)] out = glass.lognormal_gls(inp, shift) assert len(out) == 3 - assert len(out[0]) == 5 - assert len(out[1]) == 4 - assert len(out[2]) == 3 + assert out[0].shape[0] == 5 + assert out[1].shape[0] == 4 + assert out[2].shape[0] == 3 -def test_discretized_cls(compare: type[Compare]) -> None: +def test_discretized_cls(compare: type[Compare], xp: ModuleType) -> None: # empty cls result = glass.discretized_cls([]) @@ -307,47 +303,46 @@ def test_discretized_cls(compare: type[Compare]) -> None: # power spectra truncated at lmax + 1 if lmax provided result = glass.discretized_cls( - [np.arange(10), np.arange(10), np.arange(10)], + [xp.arange(10), xp.arange(10), xp.arange(10)], lmax=5, ) for cl in result: - assert len(cl) == 6 + assert cl.shape[0] == 6 # check ValueError for triangle number with pytest.raises(ValueError, match="invalid number of spectra:"): - glass.discretized_cls([np.arange(10), np.arange(10)], ncorr=1) + glass.discretized_cls([xp.arange(10), xp.arange(10)], ncorr=1) # ncorr not None - cls: AngularPowerSpectra = [np.arange(10), np.arange(10), np.arange(10)] + cls: AngularPowerSpectra = [xp.arange(10), xp.arange(10), xp.arange(10)] ncorr = 0 result = glass.discretized_cls(cls, ncorr=ncorr) - assert len(result[0]) == 10 - assert len(result[1]) == 10 - assert len(result[2]) == 0 # third correlation should be removed + assert result[0].shape[0] == 10 + assert result[1].shape[0] == 10 + assert result[2].shape[0] == 0 # third correlation should be removed # check if pixel window function was applied correctly with nside not None nside = 4 - pw = hp.pixwin(nside, lmax=7, xp=np) + pw: FloatArray = hp.pixwin(nside, lmax=7, xp=xp) - result = glass.discretized_cls([[], np.ones(10), np.ones(10)], nside=nside) + result = glass.discretized_cls( + [xp.asarray([]), xp.ones(10), xp.ones(10)], + nside=nside, + ) for cl in result: - n = min(len(cl), len(pw)) - expected = np.ones(n) * pw[:n] ** 2 # type: ignore[operator] + n = min(cl.shape[0], pw.shape[0]) + expected = xp.ones(n) * pw[:n] ** 2 compare.assert_allclose(cl[:n], expected) def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written - if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in effective_cls are not immutable, so do not support jax") - # empty cls result = glass.effective_cls([], xp.asarray([])) @@ -384,8 +379,8 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None: assert result.shape == (1, 1, 15) -def test_generate_grf(compare: type[Compare]) -> None: - gls: AngularPowerSpectra = [np.asarray([1.0, 0.5, 0.1])] +def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None: + gls: AngularPowerSpectra = [xp.asarray([1.0, 0.5, 0.1])] nside = 4 ncorr = 1 @@ -394,13 +389,13 @@ def test_generate_grf(compare: type[Compare]) -> None: assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) gaussian_fields = list(glass.fields._generate_grf(gls, nside, rng=rng)) assert gaussian_fields[0].shape == (hp.nside2npix(nside),) # requires resetting the RNG for reproducibility - rng = _rng.rng_dispatcher(xp=np) + rng = _rng.rng_dispatcher(xp=xp) new_gaussian_fields = list( glass.fields._generate_grf(gls, nside, ncorr=ncorr, rng=rng), ) @@ -410,7 +405,7 @@ def test_generate_grf(compare: type[Compare]) -> None: compare.assert_allclose(new_gaussian_fields[0], gaussian_fields[0]) with pytest.raises(ValueError, match="all gls are empty"): - list(glass.fields._generate_grf([np.asarray([])], nside)) + list(glass.fields._generate_grf([xp.asarray([])], nside)) def test_generate_gaussian(xp: ModuleType) -> None: @@ -423,19 +418,19 @@ def test_generate_lognormal(xp: ModuleType) -> None: glass.generate_lognormal([xp.asarray([1.0, 0.5, 0.1])], 4) -def test_generate(compare: type[Compare]) -> None: +def test_generate(compare: type[Compare], xp: ModuleType) -> None: # shape mismatch error fields = [lambda x, var: x, lambda x, var: x] # noqa: ARG005 with pytest.raises(ValueError, match="mismatch between number of fields and gls"): - list(glass.generate(fields, [np.ones(10), np.ones(10)], nside=16)) + list(glass.generate(fields, [xp.ones(10), xp.ones(10)], nside=16)) # check output shape nside = 16 npix = hp.nside2npix(nside) - gls: AngularPowerSpectra = [np.ones(10), np.ones(10), np.ones(10)] + gls: AngularPowerSpectra = [xp.ones(10), xp.ones(10), xp.ones(10)] result = list(glass.generate(fields, gls, nside=nside)) @@ -506,12 +501,12 @@ def test_nfields_from_nspectra(not_triangle_numbers: list[int]) -> None: glass.nfields_from_nspectra(t) -def test_enumerate_spectra() -> None: +def test_enumerate_spectra(compare: type[Compare], xp: ModuleType) -> None: n = 100 tn = n * (n + 1) // 2 # create mock spectra with 1 element counting to tn - spectra: AngularPowerSpectra = np.arange(tn).reshape(tn, 1) + spectra: AngularPowerSpectra = [xp.asarray(x) for x in range(tn)] # this is the expected order of indices indices = [(i, j) for i in range(n) for j in range(i, -1, -1)] @@ -521,7 +516,7 @@ def test_enumerate_spectra() -> None: # go through expected indices and values and compare for k, (i, j) in enumerate(indices): - assert next(it) == (i, j, k) + compare.assert_allclose(next(it), (i, j, k)) # make sure iterator is exhausted with pytest.raises(StopIteration): @@ -530,10 +525,13 @@ def test_enumerate_spectra() -> None: def test_spectra_indices(compare: type[Compare], xp: ModuleType) -> None: compare.assert_array_equal(glass.spectra_indices(0), xp.zeros((0, 2))) - compare.assert_array_equal(glass.spectra_indices(1), [[0, 0]]) - compare.assert_array_equal(glass.spectra_indices(2), [[0, 0], [1, 1], [1, 0]]) + compare.assert_array_equal(glass.spectra_indices(0, xp=xp), xp.zeros((0, 2))) + compare.assert_array_equal(glass.spectra_indices(1, xp=xp), [[0, 0]]) compare.assert_array_equal( - glass.spectra_indices(3), + glass.spectra_indices(2, xp=xp), [[0, 0], [1, 1], [1, 0]] + ) + compare.assert_array_equal( + glass.spectra_indices(3, xp=xp), [[0, 0], [1, 1], [1, 0], [2, 2], [2, 1], [2, 0]], ) @@ -653,11 +651,11 @@ def test_healpix_to_glass_spectra(compare: type[Compare]) -> None: def test_glass_to_healpix_alm(compare: type[Compare], xp: ModuleType) -> None: - inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33]) + inp = xp.asarray([00, 10, 11, 20, 21, 22, 30, 31, 32, 33], dtype=xp.complex128) out = glass.fields._glass_to_healpix_alm(inp) compare.assert_array_equal( out, - xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33]), + xp.asarray([00, 10, 20, 30, 11, 21, 31, 22, 32, 33], dtype=xp.complex128), ) @@ -671,17 +669,21 @@ def test_lognormal_shift_hilbert2011(compare: type[Compare]) -> None: compare.assert_allclose(shifts, check, atol=1e-4, rtol=1e-4) -def test_cov_from_spectra(compare: type[Compare]) -> None: - spectra: AngularPowerSpectra = np.asarray( - [ +def test_cov_from_spectra( + compare: type[Compare], + xp: ModuleType, +) -> None: + spectra: AngularPowerSpectra = [ + xp.asarray(x) + for x in [ [110, 111, 112, 113], [220, 221, 222, 223], [210, 211, 212, 213], [330, 331, 332, 333], [320, 321, 322, 323], [310, 311, 312, 313], - ], - ) + ] + ] compare.assert_array_equal( glass.cov_from_spectra(spectra), @@ -757,44 +759,47 @@ def test_cov_from_spectra(compare: type[Compare]) -> None: ) -def test_check_posdef_spectra() -> None: +def test_check_posdef_spectra(xp: ModuleType) -> None: # posdef spectra assert glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.9, 0.9, 0.9], - ], - ), + ] + ] ) # semidef spectra assert glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.9, 1.0, 0.0], - ], - ), + ] + ] ) # indef spectra assert not glass.check_posdef_spectra( - np.asarray( - [ + [ + xp.asarray(x) + for x in [ [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.1, 1.1, 1.1], - ], - ), + ] + ] ) def test_regularized_spectra( mocker: MockerFixture, - rng: np.random.Generator, + urng: UnifiedGenerator, ) -> None: - spectra: AngularPowerSpectra = rng.random(size=(6, 101)) + spectra: AngularPowerSpectra = [urng.random(101) for _ in range(6)] # test method "nearest" cov_nearest = mocker.spy(glass.algorithm, "cov_nearest") diff --git a/tests/core/test_healpix.py b/tests/core/test_healpix.py index 792c6569b..2be77e387 100644 --- a/tests/core/test_healpix.py +++ b/tests/core/test_healpix.py @@ -287,7 +287,7 @@ def test_pixwin( # Normalize to tuple old = old if isinstance(old, tuple) else (old,) - new = new if isinstance(new, tuple) else (new,) # type: ignore[redundant-expr] + new = new if isinstance(new, tuple) else (new,) assert len(old) == len(new) for i in range(len(old)):