From 9f1da4f76b57f69575143e272ed81e8be31aafc8 Mon Sep 17 00:00:00 2001 From: Rui Xue Date: Mon, 23 Mar 2026 14:29:40 -0500 Subject: [PATCH 1/4] Add Zarr v3 sharding support for main and pointing datasets in MS conversion functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: - `encoding.py`: rewrite add_encoding() with full sharding support - new `shards` parameter accepts dict (per-dim absolute shard sizes) or int (uniform factor: shard = factor × chunk for every dim) - ShardingCodec wraps BytesCodec + compressor as inner-chunk codecs - `convert_msv2_to_processing_set.py`: expose sharding at public API - add main_shards / pointing_shards (dict | int | None) parameters - forward both params to convert_and_write_partition in serial and parallel (dask.delayed) call sites - `conversion.py`: thread sharding down to the write layer - add main_shards / pointing_shards to convert_and_write_partition - pass consolidated=False to ms_xdt.to_zarr() to suppress warning - update compressor docstring type annotation --- .../_utils/_msv2/conversion.py | 37 ++++++-- .../measurement_set/_utils/_zarr/encoding.py | 85 ++++++++++++++++++- .../convert_msv2_to_processing_set.py | 26 +++++- 3 files changed, 134 insertions(+), 14 deletions(-) diff --git a/src/xradio/measurement_set/_utils/_msv2/conversion.py b/src/xradio/measurement_set/_utils/_msv2/conversion.py index 794339a7..f7da5f39 100644 --- a/src/xradio/measurement_set/_utils/_msv2/conversion.py +++ b/src/xradio/measurement_set/_utils/_msv2/conversion.py @@ -1018,6 +1018,8 @@ def convert_and_write_partition( storage_backend="zarr", parallel_mode: str = "none", persistence_mode: str = "w-", + main_shards: dict | int | None = None, + pointing_shards: dict | int | None = None, ): """_summary_ @@ -1049,8 +1051,8 @@ def convert_and_write_partition( _description_, by default None sys_cal_interpolate : bool, optional _description_, by default None - compressor : numcodecs.abc.Codec, 
optional - _description_, by default numcodecs.Zstd(level=2) + compressor : zarr.abc.codec.BytesBytesCodec, optional + _description_, by default zarr.codecs.ZstdCodec(level=2) add_reshaping_indices : bool, optional _description_, by default False storage_backend : str, optional @@ -1059,6 +1061,15 @@ def convert_and_write_partition( _description_ persistence_mode: str = "w-", _description_, by default "w-" + main_shards : dict | int | None, optional + Sharding for the main dataset. Pass a ``dict`` of absolute shard sizes + keyed by dimension name, or a positive ``int`` factor so that + ``shard_size = factor × chunk_size`` for every dimension. When set, + zarr v3 ``ShardingCodec`` is used and ``main_chunksize`` defines the + inner-chunk shape within each shard. By default ``None`` (no sharding). + pointing_shards : dict | int | None, optional + Sharding for the pointing dataset, analogous to ``main_shards``. + By default ``None``. Returns ------- @@ -1119,9 +1130,9 @@ def get_observation_info(in_file, observation_id, scan_intents): "software_name": "xradio", "version": importlib.metadata.version("xradio"), }, - "creation_date": datetime.datetime.now( - datetime.timezone.utc - ).isoformat(), + "creation_date": ( + datetime.datetime.now(datetime.timezone.utc).isoformat() + ), "type": "visibility", } ) @@ -1282,7 +1293,12 @@ def get_observation_info(in_file, observation_id, scan_intents): pointing_chunksize = parse_chunksize( pointing_chunksize, "pointing", pointing_xds ) - add_encoding(pointing_xds, compressor=compressor, chunks=pointing_chunksize) + add_encoding( + pointing_xds, + compressor=compressor, + chunks=pointing_chunksize, + shards=pointing_shards, + ) logger.debug( "Time pointing (with add compressor and chunking) " + str(time.time() - start) @@ -1360,7 +1376,9 @@ def get_observation_info(in_file, observation_id, scan_intents): # xds ready, prepare to write start = time.time() - add_encoding(xds, compressor=compressor, chunks=main_chunksize) + add_encoding( 
+ xds, compressor=compressor, chunks=main_chunksize, shards=main_shards + ) logger.debug("Time add compressor and chunk " + str(time.time() - start)) os.path.join( @@ -1417,6 +1435,7 @@ def get_observation_info(in_file, observation_id, scan_intents): store=os.path.join(out_file, ms_v4_name), mode=persistence_mode, zarr_format=ZARR_FORMAT, + consolidated=False, ) elif storage_backend == "netcdf": # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work @@ -1527,7 +1546,9 @@ def add_group_to_data_groups( "flag": "FLAG", "weight": "WEIGHT", "field_and_source": f"field_and_source_{what_group}_xds", - "description": f"Data group derived from the data column '{correlated_data_name}' of an MSv2 converted to MSv4", + "description": ( + f"Data group derived from the data column '{correlated_data_name}' of an MSv2 converted to MSv4" + ), "date": datetime.datetime.now(datetime.timezone.utc).isoformat(), } if uvw: diff --git a/src/xradio/measurement_set/_utils/_zarr/encoding.py b/src/xradio/measurement_set/_utils/_zarr/encoding.py index c22a173c..82bd2b4b 100644 --- a/src/xradio/measurement_set/_utils/_zarr/encoding.py +++ b/src/xradio/measurement_set/_utils/_zarr/encoding.py @@ -1,9 +1,88 @@ -def add_encoding(xds, compressor, chunks=None): +import zarr.codecs + + +def add_encoding( + xds, + compressor, + chunks: dict | None = None, + shards: dict | int | None = None, +) -> None: + """Set zarr encoding on every data variable in *xds*. + + Parameters + ---------- + xds : + Dataset whose data variables will have their ``.encoding`` set in-place. + compressor : + A zarr v3 ``BytesBytesCodec`` (e.g. ``zarr.codecs.ZstdCodec(level=2)``) + used for compressing inner chunks. + chunks : + Inner-chunk sizes keyed by dimension name. Missing dimensions default + to the full axis length. When *shards* is ``None`` these are the + on-disk chunk sizes; when *shards* is provided they are the inner-chunk + sizes inside each shard. + shards : + Controls zarr v3 sharding. 
+ + - ``dict[str, int]`` — per-dimension absolute shard sizes, keyed by + dimension name (same keys as *chunks*). Dimensions absent from the + dict default to the full axis length. + - ``int`` — uniform factor applied to every dimension: + ``shard_size = factor × chunk_size``. Must be a positive integer; + divisibility is guaranteed by construction. + - ``None`` (default) — no sharding. + """ if chunks is None: chunks = xds.sizes - chunks = {**dict(xds.sizes), **chunks} # Add missing sizes if presents. + chunks = {**dict(xds.sizes), **chunks} # Add missing sizes if present. for da_name in list(xds.data_vars): da_chunks = [chunks[dim_name] for dim_name in xds[da_name].sizes] - xds[da_name].encoding = {"compressors": (compressor,), "chunks": da_chunks} + + if shards is None: + xds[da_name].encoding = {"compressors": (compressor,), "chunks": da_chunks} + else: + if isinstance(shards, int): + if shards < 1: + raise ValueError( + f"Shard factor must be a positive integer, got {shards}." + ) + shard_shape = [c * shards for c in da_chunks] + else: + shard_shape = [ + shards.get(dim_name, xds.sizes[dim_name]) + for dim_name in xds[da_name].sizes + ] + # Validate: inner chunks must divide evenly into shards. + for dim_name, inner, outer in zip( + xds[da_name].dims, da_chunks, shard_shape + ): + if outer % inner != 0: + raise ValueError( + f'Shard size {outer} for dimension "{dim_name}" must be an ' + f"exact multiple of the inner chunk size {inner}." + ) + # Each Dask write task must map to exactly one shard to avoid concurrent + # read-modify-write races on the same shard file. Only rechunk when the + # existing Dask chunk shape does not already match the shard shape; an + # unnecessary rechunk inserts merge tasks in the graph and spikes memory. + arr = xds[da_name].data + already_aligned = hasattr(arr, "chunks") and all( + # Every Dask chunk along this axis equals shard_size (the last + # chunk may be smaller at the array edge — that is fine). 
+ all(c == s or c == xds.sizes[dim] % s for c in arr.chunks[i]) + for i, (dim, s) in enumerate(zip(xds[da_name].dims, shard_shape)) + ) + if not already_aligned: + shard_rechunk = dict(zip(xds[da_name].dims, shard_shape)) + xds[da_name] = xds[da_name].chunk(shard_rechunk) + xds[da_name].encoding = { + "codecs": [ + zarr.codecs.ShardingCodec( + chunk_shape=tuple(da_chunks), + codecs=[zarr.codecs.BytesCodec(), compressor], + ) + ], + "chunks": shard_shape, + } diff --git a/src/xradio/measurement_set/convert_msv2_to_processing_set.py b/src/xradio/measurement_set/convert_msv2_to_processing_set.py index 5ab39a58..9a587fc6 100644 --- a/src/xradio/measurement_set/convert_msv2_to_processing_set.py +++ b/src/xradio/measurement_set/convert_msv2_to_processing_set.py @@ -70,6 +70,8 @@ def convert_msv2_to_processing_set( storage_backend: Literal["zarr", "netcdf"] = "zarr", parallel_mode: Literal["none", "partition", "time"] = "none", persistence_mode: str = "w-", + main_shards: dict | int | None = None, + pointing_shards: dict | int | None = None, ): """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4. @@ -101,8 +103,18 @@ def convert_msv2_to_processing_set( Whether to interpolate the time axis of the system calibration data variables (sys_cal_xds) to the time axis of the main dataset use_table_iter : bool, optional Whether to use the table iterator to read the main table of the MS v2. This should be set to True when reading datasets with large number of rows and few partitions, by default False. - compressor : numcodecs.abc.Codec, optional - The Blosc compressor to use when saving the converted data to disk using Zarr, by default numcodecs.Zstd(level=2). + compressor : zarr.abc.codec.BytesBytesCodec, optional + The codec to use when saving the converted data to disk using Zarr, by default zarr.codecs.ZstdCodec(level=2). + main_shards : dict | int | None, optional + Sharding for the main dataset. 
Pass a ``dict`` of absolute shard sizes + keyed by dimension name (same keys as ``main_chunksize``), or a positive + ``int`` factor so that ``shard_size = factor × chunk_size`` for every + dimension. When provided, zarr v3 ``ShardingCodec`` is used and + ``main_chunksize`` becomes the inner-chunk size within each shard. + By default ``None`` (no sharding). + pointing_shards : dict | int | None, optional + Sharding for the pointing dataset, analogous to ``main_shards``. + By default ``None``. add_reshaping_indices : bool, optional Whether to add the tidxs, bidxs and row_id variables to each partition of the main dataset. These can be used to reshape the data back to the original ordering in the MS v2. This is mainly intended for testing and debugging, by default False. storage_backend : Literal["zarr", "netcdf"], optional @@ -126,7 +138,11 @@ def convert_msv2_to_processing_set( if not str(out_file).endswith("ps.zarr"): out_file += ".ps.zarr" - ps_dt.to_zarr(store=out_file, mode=persistence_mode, zarr_format=ZARR_FORMAT) + ps_dt.to_zarr( + store=out_file, + mode=persistence_mode, + zarr_format=ZARR_FORMAT, + ) # Check `parallel_mode` is valid try: @@ -194,6 +210,8 @@ def convert_msv2_to_processing_set( compressor=compressor, parallel_mode=parallel_mode, persistence_mode=persistence_mode, + main_shards=main_shards, + pointing_shards=pointing_shards, ) ) else: @@ -216,6 +234,8 @@ def convert_msv2_to_processing_set( compressor=compressor, parallel_mode=parallel_mode, persistence_mode=persistence_mode, + main_shards=main_shards, + pointing_shards=pointing_shards, ) end_time = time.time() logger.debug( From 346a0d202ad7e30d86ebe5968f39bf827d5ade36 Mon Sep 17 00:00:00 2001 From: Rui Xue Date: Mon, 23 Mar 2026 23:35:37 -0500 Subject: [PATCH 2/4] Use Xarray native Zarr encoding and `align_chunks` for sharding - Upgrade xarray to >=2025.09.0 and refactor to use xarray's native "shards"/"chunks" encoding keys instead of constructing ShardingCodec directly. 
- Delegate Dask rechunking to to_zarr(align_chunks=True) instead of manual alignment. - Simplify add_encoding() logic and remove zarr.codecs import. --- pyproject.toml | 12 ++--- .../_utils/_msv2/conversion.py | 2 + .../measurement_set/_utils/_zarr/encoding.py | 46 ++++--------------- 3 files changed, 17 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b6752d3f..fe580e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ readme = "README.md" requires-python = ">= 3.11, < 3.14" dependencies = [ - 'xarray' + 'xarray>=2025.09.0' ] [project.optional-dependencies] @@ -22,7 +22,7 @@ zarr = [ 'toolviper>=0.0.12', 's3fs', 'scipy', - 'xarray', + 'xarray>=2025.09.0', 'zarr>=3,<4', 'pyarrow', 'psutil' # psutil is needed so large FITS images are not loaded into memory @@ -37,7 +37,7 @@ test = [ 'toolviper>=0.0.12', 's3fs', 'scipy', - 'xarray', + 'xarray>=2025.09.0', 'zarr>=3,<4', 'pyarrow', 'psutil', @@ -52,7 +52,7 @@ casacore = [ 'toolviper>=0.0.12', 's3fs', 'scipy', - 'xarray', + 'xarray>=2025.09.0', 'zarr>=3,<4', 'pyarrow', 'psutil', @@ -64,7 +64,7 @@ interactive = [ 'toolviper>=0.0.12', 's3fs', 'scipy', - 'xarray', + 'xarray>=2025.09.0', 'zarr>=3,<4', 'pyarrow', 'psutil', @@ -94,7 +94,7 @@ all = [ 'toolviper>=0.0.12', 's3fs', 'scipy', - 'xarray', + 'xarray>=2025.09.0', 'zarr>=3,<4', 'pyarrow', 'psutil', diff --git a/src/xradio/measurement_set/_utils/_msv2/conversion.py b/src/xradio/measurement_set/_utils/_msv2/conversion.py index f7da5f39..0c49e3b2 100644 --- a/src/xradio/measurement_set/_utils/_msv2/conversion.py +++ b/src/xradio/measurement_set/_utils/_msv2/conversion.py @@ -1436,6 +1436,8 @@ def get_observation_info(in_file, observation_id, scan_intents): mode=persistence_mode, zarr_format=ZARR_FORMAT, consolidated=False, + safe_chunks=True, + align_chunks=True, ) elif storage_backend == "netcdf": # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work diff --git 
a/src/xradio/measurement_set/_utils/_zarr/encoding.py b/src/xradio/measurement_set/_utils/_zarr/encoding.py index 82bd2b4b..cb41c508 100644 --- a/src/xradio/measurement_set/_utils/_zarr/encoding.py +++ b/src/xradio/measurement_set/_utils/_zarr/encoding.py @@ -1,6 +1,3 @@ -import zarr.codecs - - def add_encoding( xds, compressor, @@ -38,11 +35,10 @@ def add_encoding( chunks = {**dict(xds.sizes), **chunks} # Add missing sizes if present. for da_name in list(xds.data_vars): - da_chunks = [chunks[dim_name] for dim_name in xds[da_name].sizes] + da_chunks = [chunks[dim_name] for dim_name in xds[da_name].dims] + encoding = {"chunks": da_chunks, "compressors": (compressor,)} - if shards is None: - xds[da_name].encoding = {"compressors": (compressor,), "chunks": da_chunks} - else: + if shards is not None: if isinstance(shards, int): if shards < 1: raise ValueError( @@ -51,38 +47,14 @@ def add_encoding( shard_shape = [c * shards for c in da_chunks] else: shard_shape = [ - shards.get(dim_name, xds.sizes[dim_name]) - for dim_name in xds[da_name].sizes + shards.get(dim, xds.sizes[dim]) for dim in xds[da_name].dims ] - # Validate: inner chunks must divide evenly into shards. - for dim_name, inner, outer in zip( - xds[da_name].dims, da_chunks, shard_shape - ): + for dim, inner, outer in zip(xds[da_name].dims, da_chunks, shard_shape): if outer % inner != 0: raise ValueError( - f'Shard size {outer} for dimension "{dim_name}" must be an ' + f'Shard size {outer} for dimension "{dim}" must be an ' f"exact multiple of the inner chunk size {inner}." ) - # Each Dask write task must map to exactly one shard to avoid concurrent - # read-modify-write races on the same shard file. Only rechunk when the - # existing Dask chunk shape does not already match the shard shape; an - # unnecessary rechunk inserts merge tasks in the graph and spikes memory. 
- arr = xds[da_name].data - already_aligned = hasattr(arr, "chunks") and all( - # Every Dask chunk along this axis equals shard_size (the last - # chunk may be smaller at the array edge — that is fine). - all(c == s or c == xds.sizes[dim] % s for c in arr.chunks[i]) - for i, (dim, s) in enumerate(zip(xds[da_name].dims, shard_shape)) - ) - if not already_aligned: - shard_rechunk = dict(zip(xds[da_name].dims, shard_shape)) - xds[da_name] = xds[da_name].chunk(shard_rechunk) - xds[da_name].encoding = { - "codecs": [ - zarr.codecs.ShardingCodec( - chunk_shape=tuple(da_chunks), - codecs=[zarr.codecs.BytesCodec(), compressor], - ) - ], - "chunks": shard_shape, - } + encoding["shards"] = shard_shape + + xds[da_name].encoding = encoding From 7e2ef5c17ab00ac6d8492df82f689baada3011ea Mon Sep 17 00:00:00 2001 From: Rui Xue Date: Mon, 23 Mar 2026 23:47:32 -0500 Subject: [PATCH 3/4] Add tests for Zarr v3 sharding functionality and update encoding logic --- .../_utils/_zarr/test_encoding.py | 140 ++++++++++++++++-- 1 file changed, 131 insertions(+), 9 deletions(-) diff --git a/tests/unit/measurement_set/_utils/_zarr/test_encoding.py b/tests/unit/measurement_set/_utils/_zarr/test_encoding.py index 751bf437..591f93d9 100644 --- a/tests/unit/measurement_set/_utils/_zarr/test_encoding.py +++ b/tests/unit/measurement_set/_utils/_zarr/test_encoding.py @@ -1,9 +1,24 @@ -import numcodecs +import dask.array as da +import numpy as np +import pytest import xarray as xr +import zarr.codecs from xradio.measurement_set._utils._zarr.encoding import add_encoding -single_encoding = numcodecs.Zstd(level=2) +compressor = zarr.codecs.ZstdCodec(level=2) + + +def _make_xds(time=6, freq=8) -> xr.Dataset: + """Small 2-D dataset with a dask-backed data variable.""" + return xr.Dataset( + { + "vis": xr.DataArray( + da.zeros((time, freq), chunks=(time, freq), dtype=np.complex64), + dims=["time", "frequency"], + ), + } + ) def test_add_encoding_wo_chunks(): @@ -18,9 +33,9 @@ def 
test_add_encoding_wo_chunks(): } ) - add_encoding(xds, single_encoding) + add_encoding(xds, compressor) assert xds - assert xds.da.encoding == {"compressors": (single_encoding,), "chunks": [3]} + assert xds.da.encoding == {"compressors": (compressor,), "chunks": [3]} def test_add_encoding_with_wrong_chunks(): @@ -35,10 +50,9 @@ def test_add_encoding_with_wrong_chunks(): } ) - chunks_size = 1 - add_encoding(xds, single_encoding, chunks={"zz_not_there": chunks_size}) + add_encoding(xds, compressor, chunks={"zz_not_there": 1}) assert xds - assert xds.da.encoding == {"compressors": (single_encoding,), "chunks": [3]} + assert xds.da.encoding == {"compressors": (compressor,), "chunks": [3]} def test_add_encoding_with_chunks(): @@ -54,9 +68,117 @@ def test_add_encoding_with_chunks(): ) chunks_size = 1 - add_encoding(xds, single_encoding, chunks={"x": chunks_size}) + add_encoding(xds, compressor, chunks={"x": chunks_size}) assert xds assert xds.da.encoding == { - "compressors": (single_encoding,), + "compressors": (compressor,), "chunks": [chunks_size], } + + +def test_sharding_sets_shards_and_chunks(): + """With sharding, encoding has 'shards' (outer) and 'chunks' (inner).""" + xds = _make_xds(time=6, freq=8) + add_encoding( + xds, + compressor, + chunks={"time": 2, "frequency": 4}, + shards={"time": 6, "frequency": 8}, + ) + + enc = xds["vis"].encoding + assert enc["chunks"] == [2, 4] # inner chunk shape + assert enc["shards"] == [6, 8] # outer shard shape + assert enc["compressors"] == (compressor,) + assert "codecs" not in enc + + +def test_sharding_inner_chunk_defaults_to_full_axis(): + """Omitting chunks= makes inner chunks equal to full axis (1 inner chunk/shard).""" + xds = _make_xds(time=6, freq=8) + add_encoding(xds, compressor, shards={"time": 6, "frequency": 8}) + + enc = xds["vis"].encoding + assert enc["chunks"] == [6, 8] # inner == full axis + assert enc["shards"] == [6, 8] + + +def test_sharding_absent_dim_defaults_to_full_axis(): + """A dimension absent 
from shards= gets shard size == full axis length.""" + xds = _make_xds(time=6, freq=8) + add_encoding( + xds, compressor, chunks={"time": 2, "frequency": 4}, shards={"time": 6} + ) + + enc = xds["vis"].encoding + assert enc["chunks"] == [2, 4] # inner chunks + assert enc["shards"] == [6, 8] # frequency shard spans full axis (8) + + +def test_sharding_compressor_always_present(): + """The compressor is in encoding even when sharding is enabled.""" + xds = _make_xds(time=6, freq=8) + add_encoding( + xds, + compressor, + chunks={"time": 2, "frequency": 4}, + shards={"time": 6, "frequency": 8}, + ) + + assert xds["vis"].encoding["compressors"] == (compressor,) + + +def test_sharding_raises_on_non_divisible_inner_chunk(): + """ValueError when inner chunk does not divide evenly into shard.""" + xds = _make_xds(time=6, freq=8) + with pytest.raises(ValueError, match="exact multiple"): + add_encoding(xds, compressor, chunks={"time": 4}, shards={"time": 6}) + + +def test_no_sharding_unchanged_when_shards_is_none(): + """Passing shards=None produces plain encoding without 'shards' key.""" + xds = _make_xds(time=6, freq=8) + add_encoding(xds, compressor, chunks={"time": 2, "frequency": 4}, shards=None) + + enc = xds["vis"].encoding + assert "shards" not in enc + assert enc == {"compressors": (compressor,), "chunks": [2, 4]} + + +def test_sharding_factor_sets_shard_shape(): + """Integer factor: shard_size == factor * chunk_size for every dim.""" + xds = _make_xds(time=6, freq=8) + add_encoding(xds, compressor, chunks={"time": 2, "frequency": 2}, shards=3) + + enc = xds["vis"].encoding + assert enc["chunks"] == [2, 2] # inner chunks unchanged + assert enc["shards"] == [6, 6] # 3 * 2 = 6 for both dims + assert enc["compressors"] == (compressor,) + + +def test_sharding_factor_shard_equals_factor_times_chunk(): + """Shard sizes are exactly factor * chunk for every dim, uncapped.""" + xds = _make_xds(time=6, freq=8) + add_encoding(xds, compressor, chunks={"time": 2, "frequency": 2}, 
shards=2) + + enc = xds["vis"].encoding + assert enc["chunks"] == [2, 2] # inner + assert enc["shards"] == [4, 4] # 2*2=4 for both dims + + +def test_sharding_factor_divisibility_always_holds(): + """Factor path must never raise a divisibility error.""" + for factor in (1, 2, 3): + xds = _make_xds(time=6, freq=8) + add_encoding(xds, compressor, chunks={"time": 1, "frequency": 1}, shards=factor) + assert xds["vis"].encoding["chunks"] == [1, 1] + assert xds["vis"].encoding["shards"] == [factor, factor] + + +def test_sharding_factor_invalid_raises(): + """Non-positive factor must raise ValueError.""" + xds = _make_xds(time=6, freq=8) + with pytest.raises(ValueError, match="positive integer"): + add_encoding(xds, compressor, shards=0) + with pytest.raises(ValueError, match="positive integer"): + add_encoding(xds, compressor, shards=-2) From b48bb421c0ee100bc0b79a17cbc64f812f2719ff Mon Sep 17 00:00:00 2001 From: Rui Xue Date: Thu, 2 Apr 2026 23:21:52 -0500 Subject: [PATCH 4/4] Implement Zarr v3 encoding utilities and integrate sharding support in image writing functions --- src/xradio/_utils/zarr/encoding.py | 67 +++++++ src/xradio/image/_util/_zarr/xds_to_zarr.py | 16 +- .../image/_util/_zarr/zarr_low_level.py | 163 ++++++------------ src/xradio/image/_util/zarr.py | 10 +- src/xradio/image/image.py | 21 ++- .../measurement_set/_utils/_zarr/encoding.py | 61 +------ 6 files changed, 159 insertions(+), 179 deletions(-) create mode 100644 src/xradio/_utils/zarr/encoding.py diff --git a/src/xradio/_utils/zarr/encoding.py b/src/xradio/_utils/zarr/encoding.py new file mode 100644 index 00000000..89dccc1e --- /dev/null +++ b/src/xradio/_utils/zarr/encoding.py @@ -0,0 +1,67 @@ +"""Shared zarr v3 encoding utilities (compressor + sharding).""" + +import zarr.abc.codec +import zarr.codecs +import xarray as xr + + +def add_encoding( + xds: xr.Dataset, + compressor: zarr.abc.codec.BytesBytesCodec, + chunks: dict | None = None, + shards: dict | int | None = None, +) -> None: + """Set 
zarr encoding on every data variable in *xds*. + + Parameters + ---------- + xds : + Dataset whose data variables will have their ``.encoding`` set in-place. + compressor : + A zarr v3 ``BytesBytesCodec`` (e.g. ``zarr.codecs.ZstdCodec(level=2)``) + used for compressing inner chunks. + chunks : + Inner-chunk sizes keyed by dimension name. Missing dimensions default + to the full axis length. When *shards* is ``None`` these are the + on-disk chunk sizes; when *shards* is provided they are the inner-chunk + sizes inside each shard. + shards : + Controls zarr v3 sharding. + + - ``dict[str, int]`` — per-dimension absolute shard sizes, keyed by + dimension name (same keys as *chunks*). Dimensions absent from the + dict default to the full axis length. + - ``int`` — uniform factor applied to every dimension: + ``shard_size = factor × chunk_size``. Must be a positive integer; + divisibility is guaranteed by construction. + - ``None`` (default) — no sharding. + """ + if chunks is None: + chunks = xds.sizes + + chunks = {**dict(xds.sizes), **chunks} # Add missing sizes if present. + + for da_name in list(xds.data_vars): + da_chunks = [chunks[dim_name] for dim_name in xds[da_name].dims] + encoding = {"chunks": da_chunks, "compressors": (compressor,)} + + if shards is not None: + if isinstance(shards, int): + if shards < 1: + raise ValueError( + f"Shard factor must be a positive integer, got {shards}." + ) + shard_shape = [c * shards for c in da_chunks] + else: + shard_shape = [ + shards.get(dim, xds.sizes[dim]) for dim in xds[da_name].dims + ] + for dim, inner, outer in zip(xds[da_name].dims, da_chunks, shard_shape): + if outer % inner != 0: + raise ValueError( + f'Shard size {outer} for dimension "{dim}" must be an ' + f"exact multiple of the inner chunk size {inner}." 
+ ) + encoding["shards"] = shard_shape + + xds[da_name].encoding = encoding diff --git a/src/xradio/image/_util/_zarr/xds_to_zarr.py b/src/xradio/image/_util/_zarr/xds_to_zarr.py index 509c2de7..71316101 100644 --- a/src/xradio/image/_util/_zarr/xds_to_zarr.py +++ b/src/xradio/image/_util/_zarr/xds_to_zarr.py @@ -2,13 +2,19 @@ import logging import numpy as np import xarray as xr +import zarr.abc.codec import os from .common import _np_types, _top_level_sub_xds from xradio._utils.zarr.config import ZARR_FORMAT -def _write_zarr(xds: xr.Dataset, zarr_store: str): +def _write_zarr( + xds: xr.Dataset, + zarr_store: str, + compressor: zarr.abc.codec.BytesBytesCodec | None = None, + shards: dict | int | None = None, +): max_chunk_size = 0.95 * 2**30 for dv in xds.data_vars: obj = xds[dv] @@ -26,6 +32,14 @@ def _write_zarr(xds: xr.Dataset, zarr_store: str): f"by at least a factor of {chunk_size_bytes/max_chunk_size}." ) xds_copy = xds.copy(deep=True) + if compressor is not None or shards is not None: + import zarr.codecs + from xradio._utils.zarr.encoding import add_encoding + + _compressor = ( + compressor if compressor is not None else zarr.codecs.ZstdCodec(level=2) + ) + add_encoding(xds_copy, compressor=_compressor, shards=shards) sub_xds_dict = _encode(xds_copy, zarr_store) z_obj = xds_copy.to_zarr(store=zarr_store, compute=True, zarr_format=ZARR_FORMAT) if sub_xds_dict: diff --git a/src/xradio/image/_util/_zarr/zarr_low_level.py b/src/xradio/image/_util/_zarr/zarr_low_level.py index d0a7bf5d..3cbce2d2 100644 --- a/src/xradio/image/_util/_zarr/zarr_low_level.py +++ b/src/xradio/image/_util/_zarr/zarr_low_level.py @@ -265,29 +265,13 @@ def create_data_variable_meta_data( ): zarr_meta = data_variables_and_dims - fs, items = _get_file_system_and_items(zarr_group_name) + from xradio._utils.zarr.config import ZARR_FORMAT for data_variable_key, dims_dtype_name in data_variables_and_dims.items(): - # print(data_variable_key, dims_dtype_name) - dims = dims_dtype_name["dims"] 
dtype = dims_dtype_name["dtype"] data_variable_name = dims_dtype_name["name"] - data_variable_path = os.path.join(zarr_group_name, data_variable_name) - if isinstance(fs, s3fs.core.S3FileSystem): - # N.b.,stateful "folder creation" is not a well defined concept for S3 objects and URIs - # see https://github.com/fsspec/s3fs/issues/401 - # nor is a path specifier (cf. "URI") - fs.mkdir(data_variable_path) - else: - # default to assuming we can use the os module and mkdir system call - os.system("mkdir " + data_variable_path) - # Create .zattrs - zattrs = { - "_ARRAY_DIMENSIONS": dims, - # "coordinates": "time declination right_ascension" - } shape = [] chunks = [] @@ -298,117 +282,68 @@ def create_data_variable_meta_data( else: chunks.append(xds_dims[d]) - # print(chunks,shape) - # assuming data_variable_path has been set compatibly - zattrs_file = os.path.join(data_variable_path, ".zattrs") - - if isinstance(fs, s3fs.core.S3FileSystem): - with fs.open(zattrs_file, "w") as file: - json.dump( - zattrs, - file, - indent=4, - sort_keys=True, - ensure_ascii=True, - separators=(",", ": "), - cls=NumberEncoder, - ) - else: - # default to assuming we can use primitives - write_json_file(zattrs, zattrs_file) - - # Create .zarray - from zarr import n5 + fill_value = "NaN" if "f" in dtype else None - compressor_config = n5.compressor_config_to_zarr( - n5.compressor_config_to_n5(compressor.get_config()) + # Determine codec pipeline for zarr array metadata. + # compressor must be a zarr v3 BytesBytesCodec (e.g. ZstdCodec). 
+ if isinstance(compressor, zarr.abc.codec.BytesBytesCodec): + codecs = [zarr.codecs.BytesCodec(), compressor] + else: + codecs = None # fall back to zarr defaults for unrecognised types + + open_kwargs: dict = dict( + mode="w", + shape=shape, + chunks=chunks, + dtype=dtype, + fill_value=fill_value, + zarr_format=ZARR_FORMAT, ) + if codecs is not None: + open_kwargs["codecs"] = codecs - if "f" in dtype: - zarray = { - "chunks": chunks, - "compressor": compressor_config, - "dtype": dtype, - "fill_value": "NaN", - "filters": None, - "order": "C", - "shape": shape, - "zarr_format": 2, - } - - else: - zarray = { - "chunks": chunks, - "compressor": compressor_config, - "dtype": dtype, - "fill_value": None, - "filters": None, - "order": "C", - "shape": shape, - "zarr_format": 2, - } + z_arr = zarr.open_array(data_variable_path, **open_kwargs) + z_arr.attrs["_ARRAY_DIMENSIONS"] = dims zarr_meta[data_variable_key]["chunks"] = chunks zarr_meta[data_variable_key]["shape"] = shape - # again, assuming data_variable_path has been set compatibly - zarray_file = os.path.join(data_variable_path, ".zarray") - - if isinstance(fs, s3fs.core.S3FileSystem): - with fs.open(zarray_file, "w") as file: - json.dump( - zarray, - file, - indent=4, - sort_keys=True, - ensure_ascii=True, - separators=(",", ": "), - cls=NumberEncoder, - ) - else: - # default to assuming we can use primitives - write_json_file(zarray, zarray_file) - return zarr_meta -def write_chunk(img_xds, meta, parallel_dims_chunk_id, compressor, image_file): +def write_chunk(img_xds, meta, parallel_dims_chunk_id, image_file): + """Write one chunk of *img_xds* into the pre-created zarr array. + + Compression is handled by zarr itself (configured in + ``create_data_variable_meta_data``); this function only computes the + correct destination slice and delegates the write to the zarr API. 
+ """ dims = meta["dims"] - dtype = meta["dtype"] data_variable_name = meta["name"] chunks = meta["chunks"] shape = meta["shape"] - chunk_name = "" - if data_variable_name in img_xds: - for d in img_xds[data_variable_name].dims: - if d in parallel_dims_chunk_id: - chunk_name = chunk_name + str(parallel_dims_chunk_id[d]) + "." - else: - chunk_name = chunk_name + "0." - chunk_name = chunk_name[:-1] - - if list(img_xds[data_variable_name].shape) != list(chunks): - array = pad_array_with_nans( - img_xds[data_variable_name].values, - output_shape=chunks, - dtype=dtype, - ) - else: - array = img_xds[data_variable_name].values - write_binary_blob_to_disk( - array, - file_path=os.path.join(image_file, data_variable_name, chunk_name), - compressor=compressor, - ) + if data_variable_name not in img_xds: + return + + array = img_xds[data_variable_name].values + + # Compute the destination slice in the full zarr array. + slices = [] + for i, d in enumerate(dims): + if d in parallel_dims_chunk_id: + idx = parallel_dims_chunk_id[d] + start = idx * chunks[i] + end = start + array.shape[i] + else: + start, end = 0, array.shape[i] + slices.append(slice(start, min(end, shape[i]))) - # z_chunk = zarr.open( - # os.path.join(image_file, data_variable_name, chunk_name), - # mode="a", - # shape=meta["shape"], - # chunks=meta["chunks"], - # dtype=meta["dtype"], - # compressor=compressor, - # ) + from xradio._utils.zarr.config import ZARR_FORMAT - # return z_chunk + z_arr = zarr.open_array( + os.path.join(image_file, data_variable_name), + mode="r+", + zarr_format=ZARR_FORMAT, + ) + z_arr[tuple(slices)] = array diff --git a/src/xradio/image/_util/zarr.py b/src/xradio/image/_util/zarr.py index c5357f80..a1ef6f08 100644 --- a/src/xradio/image/_util/zarr.py +++ b/src/xradio/image/_util/zarr.py @@ -3,11 +3,17 @@ import numpy as np import os import xarray as xr +import zarr.abc.codec from ..._utils.zarr.common import _open_dataset -def _xds_to_zarr(xds: xr.Dataset, zarr_store: str): - 
_write_zarr(xds, zarr_store)
+def _xds_to_zarr(
+    xds: xr.Dataset,
+    zarr_store: str,
+    compressor: zarr.abc.codec.BytesBytesCodec | None = None,
+    shards: dict | int | None = None,
+):
+    _write_zarr(xds, zarr_store, compressor=compressor, shards=shards)
 
 
 def _xds_from_zarr(
diff --git a/src/xradio/image/image.py b/src/xradio/image/image.py
index 9805a333..143365e4 100755
--- a/src/xradio/image/image.py
+++ b/src/xradio/image/image.py
@@ -11,6 +11,7 @@
 import shutil
 import toolviper.utils.logger as logger
 import xarray as xr
+import zarr.abc.codec
 
 from xradio.image._util.image_factory import (
     _make_empty_aperture_image,
@@ -192,7 +193,12 @@ def load_image(store: str, block_des: dict = None, do_sky_coords=True) -> xr.Dat
 
 
 def write_image(
-    xds: xr.Dataset, imagename: str, out_format: str = "casa", overwrite: bool = False
+    xds: xr.Dataset,
+    imagename: str,
+    out_format: str = "casa",
+    overwrite: bool = False,
+    compressor: zarr.abc.codec.BytesBytesCodec | None = None,
+    shards: dict | int | None = None,
 ) -> None:
     """
     TODO: I think the user should be permitted to specify data groups to write.
@@ -209,6 +215,17 @@ def write_image(
         Format of output image, currently "casa" and "zarr" are supported
     overwrite : bool
         If True, overwrite existing image. Default is False.
+    compressor : zarr.abc.codec.BytesBytesCodec | None
+        Codec used to compress inner chunks. Only used when ``out_format='zarr'``.
+        Defaults to ``zarr.codecs.ZstdCodec(level=2)`` when ``shards`` is
+        provided; ``None`` means use the xarray/zarr defaults.
+    shards : dict | int | None
+        Sharding specification. Pass a ``dict`` of absolute shard sizes keyed
+        by dimension name (e.g. ``{'l': 512, 'm': 512}``), or a positive ``int``
+        factor so that ``shard_size = factor × chunk_size`` for every dimension.
+        Only used when ``out_format='zarr'``. Requires zarr v3. When ``None``
+        (default) no sharding is applied.
+ Returns ------- None @@ -232,7 +249,7 @@ def write_image( _xds_to_multiple_casa_images(xds, imagename) elif my_format == "zarr": - _xds_to_zarr(xds, imagename) + _xds_to_zarr(xds, imagename, compressor=compressor, shards=shards) else: raise ValueError( f"Writing to format {out_format} is not supported. " diff --git a/src/xradio/measurement_set/_utils/_zarr/encoding.py b/src/xradio/measurement_set/_utils/_zarr/encoding.py index cb41c508..7e3ef13b 100644 --- a/src/xradio/measurement_set/_utils/_zarr/encoding.py +++ b/src/xradio/measurement_set/_utils/_zarr/encoding.py @@ -1,60 +1 @@ -def add_encoding( - xds, - compressor, - chunks: dict | None = None, - shards: dict | int | None = None, -) -> None: - """Set zarr encoding on every data variable in *xds*. - - Parameters - ---------- - xds : - Dataset whose data variables will have their ``.encoding`` set in-place. - compressor : - A zarr v3 ``BytesBytesCodec`` (e.g. ``zarr.codecs.ZstdCodec(level=2)``) - used for compressing inner chunks. - chunks : - Inner-chunk sizes keyed by dimension name. Missing dimensions default - to the full axis length. When *shards* is ``None`` these are the - on-disk chunk sizes; when *shards* is provided they are the inner-chunk - sizes inside each shard. - shards : - Controls zarr v3 sharding. - - - ``dict[str, int]`` — per-dimension absolute shard sizes, keyed by - dimension name (same keys as *chunks*). Dimensions absent from the - dict default to the full axis length. - - ``int`` — uniform factor applied to every dimension: - ``shard_size = factor × chunk_size``. Must be a positive integer; - divisibility is guaranteed by construction. - - ``None`` (default) — no sharding. - """ - if chunks is None: - chunks = xds.sizes - - chunks = {**dict(xds.sizes), **chunks} # Add missing sizes if present. 
- - for da_name in list(xds.data_vars): - da_chunks = [chunks[dim_name] for dim_name in xds[da_name].dims] - encoding = {"chunks": da_chunks, "compressors": (compressor,)} - - if shards is not None: - if isinstance(shards, int): - if shards < 1: - raise ValueError( - f"Shard factor must be a positive integer, got {shards}." - ) - shard_shape = [c * shards for c in da_chunks] - else: - shard_shape = [ - shards.get(dim, xds.sizes[dim]) for dim in xds[da_name].dims - ] - for dim, inner, outer in zip(xds[da_name].dims, da_chunks, shard_shape): - if outer % inner != 0: - raise ValueError( - f'Shard size {outer} for dimension "{dim}" must be an ' - f"exact multiple of the inner chunk size {inner}." - ) - encoding["shards"] = shard_shape - - xds[da_name].encoding = encoding +from xradio._utils.zarr.encoding import add_encoding # noqa: F401