Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pscpy/psc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def decode_psc(
if var_name in field_to_component:
for field, component in field_to_component[var_name].items(): # type: ignore[index]
data_vars[field] = ds[var_name].isel({f"comp_{var_name}": component})
ds = ds.drop_vars([var_name])
ds = ds.assign(data_vars)

if length is not None:
Expand Down
64 changes: 40 additions & 24 deletions tests/test_xarray_adios2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ def test_filename_4(tmp_path):
return filename


def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset:
def _open_dataset(filename: os.PathLike[Any], *, decode: bool = False) -> xr.Dataset:
ds = xr.open_dataset(filename)
if decode:
ds = _decode_dataset(ds)
return ds


def _decode_dataset(ds: xr.Dataset) -> xr.Dataset:
return pscpy.decode_psc(
ds,
species_names=["e", "i"],
Expand All @@ -81,46 +87,56 @@ def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset:


def test_open_dataset():
ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
assert "jx_ec" in ds
assert ds.coords.keys() == set({"x", "y", "z"})
assert ds.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408
assert np.allclose(ds.jx_ec.z, np.linspace(-25.6, 25.6, 512, endpoint=False))
ds_decoded = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp", decode=True)
assert "jx_ec" in ds_decoded
assert ds_decoded.coords.keys() == set({"x", "y", "z"})
assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408
assert np.allclose(
ds_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data
)


def test_component():
ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408
assert np.all(ds.jeh.isel(comp_jeh=0) == ds.jx_ec)
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert np.all(ds_raw.jeh.isel(dim_1_9=0).data == ds_decoded.jx_ec.data)


def test_selection():
ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert np.all(
ds.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40))
== ds.jx_ec.isel(y=slice(0, 10), z=slice(0, 40))
ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data
== ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data
)


def test_nbytes():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert ds_decoded.nbytes == ds_decoded.nbytes


def test_computed():
ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds = ds.assign(jx=ds.jeh.isel(comp_jeh=0))
assert np.all(ds.jx == ds.jx_ec)
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0))
assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data)


def test_computed_via_lambda():
ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds = ds.assign(jx=lambda ds: ds.jeh.isel(comp_jeh=0))
assert np.all(ds.jx == ds.jx_ec)
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0))
assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data)


def test_pfd_moments():
ds = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp")
assert "all_1st" in ds
assert ds.all_1st.sizes == dict(x=1, y=128, z=512, comp_all_1st=26) # noqa: C408
assert "rho_i" in ds
assert np.all(ds.rho_i == ds.all_1st.isel(comp_all_1st=13))
ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert "all_1st" in ds_raw
assert "rho_i" in ds_decoded
assert np.all(ds_decoded.rho_i.data == ds_raw.all_1st.isel(dim_1_26=13).data)


def test_open_dataset_steps(test_filename):
Expand Down