Skip to content
2 changes: 1 addition & 1 deletion glass/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _generate_grf(
alm = _glass_to_healpix_alm(alm)

# modes with m = 0 are real-valued and come first in array
alm = xpx.at(alm)[:n].set(xp.real(alm[:n])[:] + xp.imag(alm[:n]) + 0j)
alm = xpx.at(alm)[:n].set(xp.real(alm[:n]) + xp.imag(alm[:n]) + 0j)

# transform alm to maps
# can be performed in place on the temporary alm array
Expand Down
27 changes: 10 additions & 17 deletions tests/benchmarks/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from types import ModuleType
from typing import Any

from conftest import Compare, GeneratorConsumer
from pytest_benchmark.fixture import BenchmarkFixture

from glass._types import AngularPowerSpectra, UnifiedGenerator
from tests.fixtures.helper_classes import Compare, GeneratorConsumer


@pytest.mark.stable
Expand All @@ -31,7 +31,7 @@ def test_iternorm_no_size(

def function_to_benchmark() -> list[Any]:
generator = glass.iternorm(k, iter(array_in))
return generator_consumer.consume( # type: ignore[no-any-return]
return generator_consumer.consume(
generator,
valid_exception="covariance matrix is not positive definite",
)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_iternorm_specify_size(

def function_to_benchmark() -> list[Any]:
generator = glass.iternorm(k, iter(array_in), size)
return generator_consumer.consume( # type: ignore[no-any-return]
return generator_consumer.consume(
generator,
valid_exception="covariance matrix is not positive definite",
)
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_iternorm_k_0(

def function_to_benchmark() -> list[Any]:
generator = glass.iternorm(k, iter(array_in))
return generator_consumer.consume(generator) # type: ignore[no-any-return]
return generator_consumer.consume(generator)

results = benchmark(function_to_benchmark)

Expand Down Expand Up @@ -140,7 +140,7 @@ def function_to_benchmark() -> list[Any]:
nf,
nc,
)
return generator_consumer.consume(generator) # type: ignore[no-any-return]
return generator_consumer.consume(generator)

covs = benchmark(function_to_benchmark)
cov = covs[0]
Expand All @@ -156,18 +156,14 @@ def function_to_benchmark() -> list[Any]:
@pytest.mark.stable
@pytest.mark.parametrize("use_rng", [False, True])
@pytest.mark.parametrize("ncorr", [None, 1])
def test_generate_grf( # noqa: PLR0913
xpb: ModuleType,
def test_generate_grf(
benchmark: BenchmarkFixture,
generator_consumer: GeneratorConsumer,
ncorr: int | None,
urngb: UnifiedGenerator,
use_rng: bool, # noqa: FBT001
ncorr: int | None,
) -> None:
"""Benchmarks for glass.fields._generate_grf with positional arguments only."""
if xpb.__name__ == "array_api_strict":
pytest.skip(f"glass.fields._generate_grf not yet ported for {xpb.__name__}")

gls: AngularPowerSpectra = [urngb.random(1_000)]
nside = 4

Expand All @@ -178,7 +174,7 @@ def function_to_benchmark() -> list[Any]:
rng=urngb if use_rng else None,
ncorr=ncorr,
)
return generator_consumer.consume(generator) # type: ignore[no-any-return]
return generator_consumer.consume(generator)

gaussian_fields = benchmark(function_to_benchmark)

Expand All @@ -195,9 +191,6 @@ def test_generate(
ncorr: int | None,
) -> None:
"""Benchmarks for glass.generate."""
if xpb.__name__ == "array_api_strict":
pytest.skip(f"glass.generate not yet ported for {xpb.__name__}")

n = 100
fields = [lambda x, var: x for _ in range(n)] # noqa: ARG005
fields[1] = lambda x, var: x**2 # noqa: ARG005
Expand All @@ -212,8 +205,8 @@ def function_to_benchmark() -> list[Any]:
nside=nside,
ncorr=ncorr,
)
return generator_consumer.consume( # type: ignore[no-any-return]
generator,
return generator_consumer.consume(
generator, # type: ignore[arg-type]
valid_exception="covariance matrix is not positive definite",
)

Expand Down
4 changes: 0 additions & 4 deletions tests/benchmarks/test_galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,9 @@ def test_redshifts_from_nz(
def test_galaxy_shear(
benchmark: BenchmarkFixture,
urngb: UnifiedGenerator,
xpb: ModuleType,
reduced_shear: bool, # noqa: FBT001
) -> None:
"""Benchmark for galaxies.galaxy_shear."""
if xpb.__name__ == "array_api_strict":
pytest.skip(f"glass.galaxy_shear not yet ported for {xpb.__name__}")

scale_factor = 100
size = (12 * scale_factor,)
kappa = urngb.normal(size=size)
Expand Down
3 changes: 2 additions & 1 deletion tests/benchmarks/test_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
if TYPE_CHECKING:
from types import ModuleType

from conftest import Compare
from pytest_benchmark.fixture import BenchmarkFixture

from tests.fixtures.helper_classes import Compare


@pytest.mark.unstable
def test_multalm(
Expand Down
5 changes: 1 addition & 4 deletions tests/benchmarks/test_lensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
if TYPE_CHECKING:
from types import ModuleType

from conftest import Compare
from pytest_benchmark.fixture import BenchmarkFixture
from typing_extensions import Never

from glass._types import FloatArray, UnifiedGenerator
from glass.cosmology import Cosmology
from tests.fixtures.helper_classes import Compare


@pytest.mark.stable
Expand Down Expand Up @@ -85,9 +85,6 @@ def test_multi_plane_weights(
xpb: ModuleType,
) -> None:
"""Benchmarks for add_window and add_plane with a multi_plane_weights."""
if xpb.__name__ == "array_api_strict":
pytest.skip(f"glass.multi_plane_weights not yet ported for {xpb.__name__}")

# Use this over the fixture to allow us to add many more windows
shells = [
glass.RadialWindow(
Expand Down
4 changes: 0 additions & 4 deletions tests/benchmarks/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ def test_positions_from_delta( # noqa: PLR0913
remove_monopole: bool, # noqa: FBT001
) -> None:
"""Benchmarks for glass.positions_from_delta."""
if xpb.__name__ == "array_api_strict":
pytest.skip(
f"glass.lensing.multi_plane_matrix not yet ported for {xpb.__name__}",
)
nside = 48
npix = 12 * nside * nside

Expand Down
3 changes: 2 additions & 1 deletion tests/benchmarks/test_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
if TYPE_CHECKING:
from types import ModuleType

from conftest import Compare
from pytest_benchmark.fixture import BenchmarkFixture

from tests.fixtures.helper_classes import Compare


@pytest.mark.unstable
def test_radialwindow(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_effective_cls(compare: type[Compare], xp: ModuleType) -> None:


def test_generate_grf(compare: type[Compare], xp: ModuleType) -> None:
gls = [xp.asarray([1.0, 0.5, 0.1])]
gls: AngularPowerSpectra = [xp.asarray([1.0, 0.5, 0.1])]
nside = 4
ncorr = 1

Expand Down
7 changes: 3 additions & 4 deletions tests/core/test_galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None:
if xp.__name__ == "jax.numpy":
pytest.skip("Arrays in redshifts are not immutable, so do not support jax")
pytest.skip(f"glass.redshifts not yet ported for {xp.__name__}")

# create a mock radial window function
w = mocker.Mock()
w.za = xp.linspace(0.0, 1.0, 20)
Expand All @@ -37,9 +38,7 @@ def test_redshifts(mocker: MockerFixture, xp: ModuleType) -> None:

def test_redshifts_from_nz(urng: UnifiedGenerator, xp: ModuleType) -> None:
if xp.__name__ == "jax.numpy":
pytest.skip(
"Arrays in redshifts_from_nz are not immutable, so do not support jax",
)
pytest.skip(f"glass.redshifts_from_nz not yet ported for {xp.__name__}")

# test sampling

Expand Down