diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 36dd00d..2e3ae15 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -118,6 +118,7 @@ def decode_psc( if var_name in field_to_component: for field, component in field_to_component[var_name].items(): # type: ignore[index] data_vars[field] = ds[var_name].isel({f"comp_{var_name}": component}) + ds = ds.drop_vars([var_name]) ds = ds.assign(data_vars) if length is not None: diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 74f0513..f852bf3 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -70,8 +70,14 @@ def test_filename_4(tmp_path): return filename -def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset: +def _open_dataset(filename: os.PathLike[Any], *, decode: bool = False) -> xr.Dataset: ds = xr.open_dataset(filename) + if decode: + ds = _decode_dataset(ds) + return ds + + +def _decode_dataset(ds: xr.Dataset) -> xr.Dataset: return pscpy.decode_psc( ds, species_names=["e", "i"], @@ -81,46 +87,56 @@ def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset: def test_open_dataset(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert "jx_ec" in ds - assert ds.coords.keys() == set({"x", "y", "z"}) - assert ds.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 - assert np.allclose(ds.jx_ec.z, np.linspace(-25.6, 25.6, 512, endpoint=False)) + ds_decoded = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp", decode=True) + assert "jx_ec" in ds_decoded + assert ds_decoded.coords.keys() == set({"x", "y", "z"}) + assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 + assert np.allclose( + ds_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data + ) def test_component(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 - assert np.all(ds.jeh.isel(comp_jeh=0) == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert np.all(ds_raw.jeh.isel(dim_1_9=0).data == ds_decoded.jx_ec.data) def test_selection(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) assert np.all( - ds.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40)) - == ds.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) + ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data + == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data ) +def test_nbytes(): + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert ds_decoded.nbytes == ds_decoded.nbytes + + def test_computed(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds = ds.assign(jx=ds.jeh.isel(comp_jeh=0)) - assert np.all(ds.jx == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0)) + assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) def test_computed_via_lambda(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds = ds.assign(jx=lambda ds: ds.jeh.isel(comp_jeh=0)) - assert np.all(ds.jx == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0)) + assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) def test_pfd_moments(): - ds = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") - assert "all_1st" in ds - assert ds.all_1st.sizes == dict(x=1, y=128, z=512, comp_all_1st=26) # noqa: C408 - assert "rho_i" in ds - assert np.all(ds.rho_i == ds.all_1st.isel(comp_all_1st=13)) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert "all_1st" in ds_raw + assert "rho_i" in ds_decoded + assert np.all(ds_decoded.rho_i.data == ds_raw.all_1st.isel(dim_1_26=13).data) def test_open_dataset_steps(test_filename):