From fe612a5b217fbc7bfe9c9ed552bca93927fb3471 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 14:55:50 -0500 Subject: [PATCH 1/9] psc: drop original data to avoid duping during decode --- src/pscpy/psc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 36dd00d..fb193ae 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -118,6 +118,7 @@ def decode_psc( if var_name in field_to_component: for field, component in field_to_component[var_name].items(): # type: ignore[index] data_vars[field] = ds[var_name].isel({f"comp_{var_name}": component}) + ds.drop_vars(var_name) ds = ds.assign(data_vars) if length is not None: From dad357efaf94c0de0bca9bb657f46e031203cf3b Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:04:22 -0500 Subject: [PATCH 2/9] psc: fix for drop_vars not being in-place --- src/pscpy/psc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index fb193ae..1e85dee 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -118,7 +118,7 @@ def decode_psc( if var_name in field_to_component: for field, component in field_to_component[var_name].items(): # type: ignore[index] data_vars[field] = ds[var_name].isel({f"comp_{var_name}": component}) - ds.drop_vars(var_name) + ds = ds.drop_vars(var_name) ds = ds.assign(data_vars) if length is not None: From 7e654150e7968aef202240c13379f8cab2580eeb Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:06:41 -0500 Subject: [PATCH 3/9] psc: fix for drop_vars taking [Hashable] --- src/pscpy/psc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 1e85dee..2e3ae15 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -118,7 +118,7 @@ def decode_psc( if var_name in field_to_component: for field, component in field_to_component[var_name].items(): # type: ignore[index] data_vars[field] = ds[var_name].isel({f"comp_{var_name}": component}) - ds = ds.drop_vars(var_name) + ds = ds.drop_vars([var_name]) ds = ds.assign(data_vars) if length is not None: From c161fd491b3646fc61847c63a786d0208cc3d1b1 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:22:17 -0500 Subject: [PATCH 4/9] test: compare decoded to raw datasets --- tests/test_xarray_adios2.py | 61 ++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 74f0513..ccaf005 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -70,8 +70,14 @@ def test_filename_4(tmp_path): return filename -def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset: +def _open_dataset(filename: os.PathLike[Any], *, decode: bool = False) -> xr.Dataset: ds = xr.open_dataset(filename) + if decode: + ds = _decode_dataset(ds) + return ds + + +def _decode_dataset(ds: xr.Dataset) -> xr.Dataset: return pscpy.decode_psc( ds, species_names=["e", "i"], @@ -81,46 +87,53 @@ def _open_dataset(filename: os.PathLike[Any]) -> xr.Dataset: def test_open_dataset(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert "jx_ec" in ds - assert ds.coords.keys() == set({"x", "y", "z"}) - assert ds.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 - assert np.allclose(ds.jx_ec.z, np.linspace(-25.6, 25.6, 512, endpoint=False)) + ds_decoded = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp", decode=True) + assert "jx_ec" in ds_decoded + assert ds_decoded.coords.keys() == set({"x", "y", "z"}) + assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 + assert np.allclose( + ds_decoded.jx_ec.z, np.linspace(-25.6, 25.6, 512, endpoint=False) + ) def test_component(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 - assert np.all(ds.jeh.isel(comp_jeh=0) == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert ds_raw.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 + assert np.all(ds_raw.jeh.isel(comp_jeh=0) == ds_decoded.jx_ec) def test_selection(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - assert ds.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert ds_raw.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 assert np.all( - ds.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40)) - == ds.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) + ds_raw.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40)) + == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) ) def test_computed(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds = ds.assign(jx=ds.jeh.isel(comp_jeh=0)) - assert np.all(ds.jx == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(comp_jeh=0)) + assert np.all(ds_raw.jx == ds_decoded.jx_ec) def test_computed_via_lambda(): - ds = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") - ds = ds.assign(jx=lambda ds: ds.jeh.isel(comp_jeh=0)) - assert np.all(ds.jx == ds.jx_ec) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(comp_jeh=0)) + assert np.all(ds_raw.jx == ds_decoded.jx_ec) def test_pfd_moments(): - ds = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") - assert "all_1st" in ds - assert ds.all_1st.sizes == dict(x=1, y=128, z=512, comp_all_1st=26) # noqa: C408 - assert "rho_i" in ds - assert np.all(ds.rho_i == ds.all_1st.isel(comp_all_1st=13)) + ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert "all_1st" in ds_raw + assert ds_raw.all_1st.sizes == dict(x=1, y=128, z=512, comp_all_1st=26) # noqa: C408 + assert "rho_i" in ds_decoded + assert np.all(ds_decoded.rho_i == ds_raw.all_1st.isel(comp_all_1st=13)) def test_open_dataset_steps(test_filename): From c8c555a6e8e7de90dc1df404390d9e04639bf825 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:26:17 -0500 Subject: [PATCH 5/9] test: rm check for sizes of raw variables --- tests/test_xarray_adios2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index ccaf005..033ae18 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -99,14 +99,12 @@ def test_open_dataset(): def test_component(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert ds_raw.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 assert np.all(ds_raw.jeh.isel(comp_jeh=0) == ds_decoded.jx_ec) def test_selection(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert ds_raw.jeh.sizes == dict(x=1, y=128, z=512, comp_jeh=9) # noqa: C408 assert np.all( ds_raw.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40)) == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) @@ -131,7 +129,6 @@ def test_pfd_moments(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp") ds_decoded = _decode_dataset(ds_raw) assert "all_1st" in ds_raw - assert ds_raw.all_1st.sizes == dict(x=1, y=128, z=512, comp_all_1st=26) # noqa: C408 assert "rho_i" in ds_decoded assert np.all(ds_decoded.rho_i == ds_raw.all_1st.isel(comp_all_1st=13)) From b97444c29a4d3b2406aa44c782fd412c56e06f8f Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:27:35 -0500 Subject: [PATCH 6/9] test: fix dim names for raw ds --- tests/test_xarray_adios2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 033ae18..7676f33 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -99,14 +99,14 @@ def test_open_dataset(): def test_component(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert np.all(ds_raw.jeh.isel(comp_jeh=0) == ds_decoded.jx_ec) + assert np.all(ds_raw.jeh.isel(dim_1_9=0) == ds_decoded.jx_ec) def test_selection(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) assert np.all( - ds_raw.jeh.isel(comp_jeh=0, y=slice(0, 10), z=slice(0, 40)) + ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)) == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) ) @@ -114,14 +114,14 @@ def test_selection(): def test_computed(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(comp_jeh=0)) + ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0)) assert np.all(ds_raw.jx == ds_decoded.jx_ec) def test_computed_via_lambda(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(comp_jeh=0)) + ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0)) assert np.all(ds_raw.jx == ds_decoded.jx_ec) @@ -130,7 +130,7 @@ def test_pfd_moments(): ds_decoded = _decode_dataset(ds_raw) assert "all_1st" in ds_raw assert "rho_i" in ds_decoded - assert np.all(ds_decoded.rho_i == ds_raw.all_1st.isel(comp_all_1st=13)) + assert np.all(ds_decoded.rho_i == ds_raw.all_1st.isel(dim_1_26=13)) def test_open_dataset_steps(test_filename): From 811e67ae25a76af046cfc9edbc8a4e9d62385abe Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:38:54 -0500 Subject: [PATCH 7/9] test: fix comparison of data I think this was actually a harmless problem before, too. It seems `==` of xarray datasets returns a bool, not an array, so the `np.all`s were redundant. Also, dataset `==` must have included dimension names. I broke the equality of dimension names, but not the equality of the data itself, as these tests now show. --- tests/test_xarray_adios2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 7676f33..129c62e 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -92,22 +92,22 @@ def test_open_dataset(): assert ds_decoded.coords.keys() == set({"x", "y", "z"}) assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 assert np.allclose( - ds_decoded.jx_ec.z, np.linspace(-25.6, 25.6, 512, endpoint=False) + ds_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data ) def test_component(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert np.all(ds_raw.jeh.isel(dim_1_9=0) == ds_decoded.jx_ec) + assert np.all(ds_raw.jeh.isel(dim_1_9=0).data == ds_decoded.jx_ec.data) def test_selection(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) assert np.all( - ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)) - == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)) + ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data + == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data ) @@ -115,14 +115,14 @@ def test_computed(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0)) - assert np.all(ds_raw.jx == ds_decoded.jx_ec) + assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) def test_computed_via_lambda(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0)) - assert np.all(ds_raw.jx == ds_decoded.jx_ec) + assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data) def test_pfd_moments(): @@ -130,7 +130,7 @@ def test_pfd_moments(): ds_decoded = _decode_dataset(ds_raw) assert "all_1st" in ds_raw assert "rho_i" in ds_decoded - assert np.all(ds_decoded.rho_i == ds_raw.all_1st.isel(dim_1_26=13)) + assert np.all(ds_decoded.rho_i.data == ds_raw.all_1st.isel(dim_1_26=13).data) def test_open_dataset_steps(test_filename): From bc050dabccaecd9bc78ee54eb554b30a963f01f3 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:43:02 -0500 Subject: [PATCH 8/9] test: +test_nbytes --- tests/test_xarray_adios2.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 129c62e..8460671 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -105,10 +105,13 @@ def test_component(): def test_selection(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert np.all( - ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data - == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data - ) + assert np.all(ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data) + + +def test_nbytes(): + ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") + ds_decoded = _decode_dataset(ds_raw) + assert ds_decoded.nbytes == ds_decoded.nbytes def test_computed(): From f6d5e5c895a02de3a43c135fd4709bdd368deab0 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 17 Feb 2025 15:55:20 -0500 Subject: [PATCH 9/9] test: fix formatting --- tests/test_xarray_adios2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index 8460671..f852bf3 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -105,7 +105,10 @@ def test_component(): def test_selection(): ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp") ds_decoded = _decode_dataset(ds_raw) - assert np.all(ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data) + assert np.all( + ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data + == ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data + ) def test_nbytes():