Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ readme = "README.md"
requires-python = ">= 3.11, < 3.14"

dependencies = [
'xarray'
'xarray>=2025.09.0'
]

[project.optional-dependencies]
Expand All @@ -22,7 +22,7 @@ zarr = [
'toolviper>=0.0.12',
's3fs',
'scipy',
'xarray',
'xarray>=2025.09.0',
'zarr>=3,<4',
'pyarrow',
'psutil' # psutil is needed so large FITS images are not loaded into memory
Expand All @@ -37,7 +37,7 @@ test = [
'toolviper>=0.0.12',
's3fs',
'scipy',
'xarray',
'xarray>=2025.09.0',
'zarr>=3,<4',
'pyarrow',
'psutil',
Expand All @@ -52,7 +52,7 @@ casacore = [
'toolviper>=0.0.12',
's3fs',
'scipy',
'xarray',
'xarray>=2025.09.0',
'zarr>=3,<4',
'pyarrow',
'psutil',
Expand All @@ -64,7 +64,7 @@ interactive = [
'toolviper>=0.0.12',
's3fs',
'scipy',
'xarray',
'xarray>=2025.09.0',
'zarr>=3,<4',
'pyarrow',
'psutil',
Expand Down Expand Up @@ -94,7 +94,7 @@ all = [
'toolviper>=0.0.12',
's3fs',
'scipy',
'xarray',
'xarray>=2025.09.0',
'zarr>=3,<4',
'pyarrow',
'psutil',
Expand Down
67 changes: 67 additions & 0 deletions src/xradio/_utils/zarr/encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Shared zarr v3 encoding utilities (compressor + sharding)."""

import zarr.abc.codec
import zarr.codecs
import xarray as xr


def add_encoding(
    xds: xr.Dataset,
    compressor: zarr.abc.codec.BytesBytesCodec,
    chunks: dict | None = None,
    shards: dict | int | None = None,
) -> None:
    """Attach zarr v3 encoding (chunking, compression, optional sharding)
    to every data variable of *xds*, mutating each ``.encoding`` in place.

    Parameters
    ----------
    xds :
        Dataset whose data variables receive the new encoding.
    compressor :
        zarr v3 ``BytesBytesCodec`` (e.g. ``zarr.codecs.ZstdCodec(level=2)``)
        applied to the inner chunks.
    chunks :
        Inner-chunk size per dimension name; any dimension not listed falls
        back to its full axis length. Without *shards* these are the on-disk
        chunk sizes; with *shards* they are the chunk sizes inside a shard.
    shards :
        Controls zarr v3 sharding.

        - ``dict[str, int]`` — absolute per-dimension shard sizes keyed by
          dimension name; missing dimensions default to the full axis length.
        - ``int`` — uniform multiplier, so every shard spans
          ``factor × chunk`` elements along each dimension (divisible by
          construction). Must be a positive integer.
        - ``None`` (default) — no sharding.

    Raises
    ------
    ValueError
        If an integer shard factor is < 1, or a shard size is not an exact
        multiple of the corresponding inner chunk size.
    """
    full_sizes = dict(xds.sizes)
    # Dimensions the caller omitted default to the full axis length.
    effective_chunks = full_sizes if chunks is None else {**full_sizes, **chunks}

    for var_name in list(xds.data_vars):
        var_dims = xds[var_name].dims
        inner_chunks = [effective_chunks[d] for d in var_dims]
        var_encoding = {"chunks": inner_chunks, "compressors": (compressor,)}

        if shards is not None:
            if isinstance(shards, int):
                if shards < 1:
                    raise ValueError(
                        f"Shard factor must be a positive integer, got {shards}."
                    )
                outer_shape = [shards * c for c in inner_chunks]
            else:
                outer_shape = [shards.get(d, full_sizes[d]) for d in var_dims]
            # Each shard must hold a whole number of inner chunks.
            for dim, inner, outer in zip(var_dims, inner_chunks, outer_shape):
                if outer % inner:
                    raise ValueError(
                        f'Shard size {outer} for dimension "{dim}" must be an '
                        f"exact multiple of the inner chunk size {inner}."
                    )
            var_encoding["shards"] = outer_shape

        xds[var_name].encoding = var_encoding
16 changes: 15 additions & 1 deletion src/xradio/image/_util/_zarr/xds_to_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
import logging
import numpy as np
import xarray as xr
import zarr.abc.codec
import os
from .common import _np_types, _top_level_sub_xds

from xradio._utils.zarr.config import ZARR_FORMAT


def _write_zarr(xds: xr.Dataset, zarr_store: str):
def _write_zarr(
xds: xr.Dataset,
zarr_store: str,
compressor: zarr.abc.codec.BytesBytesCodec | None = None,
shards: dict | int | None = None,
):
max_chunk_size = 0.95 * 2**30
for dv in xds.data_vars:
obj = xds[dv]
Expand All @@ -26,6 +32,14 @@ def _write_zarr(xds: xr.Dataset, zarr_store: str):
f"by at least a factor of {chunk_size_bytes/max_chunk_size}."
)
xds_copy = xds.copy(deep=True)
if compressor is not None or shards is not None:
import zarr.codecs
from xradio._utils.zarr.encoding import add_encoding

_compressor = (
compressor if compressor is not None else zarr.codecs.ZstdCodec(level=2)
)
add_encoding(xds_copy, compressor=_compressor, shards=shards)
sub_xds_dict = _encode(xds_copy, zarr_store)
z_obj = xds_copy.to_zarr(store=zarr_store, compute=True, zarr_format=ZARR_FORMAT)
if sub_xds_dict:
Expand Down
163 changes: 49 additions & 114 deletions src/xradio/image/_util/_zarr/zarr_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,29 +265,13 @@ def create_data_variable_meta_data(
):
zarr_meta = data_variables_and_dims

fs, items = _get_file_system_and_items(zarr_group_name)
from xradio._utils.zarr.config import ZARR_FORMAT

for data_variable_key, dims_dtype_name in data_variables_and_dims.items():
# print(data_variable_key, dims_dtype_name)

dims = dims_dtype_name["dims"]
dtype = dims_dtype_name["dtype"]
data_variable_name = dims_dtype_name["name"]

data_variable_path = os.path.join(zarr_group_name, data_variable_name)
if isinstance(fs, s3fs.core.S3FileSystem):
# N.b.,stateful "folder creation" is not a well defined concept for S3 objects and URIs
# see https://github.com/fsspec/s3fs/issues/401
# nor is a path specifier (cf. "URI")
fs.mkdir(data_variable_path)
else:
# default to assuming we can use the os module and mkdir system call
os.system("mkdir " + data_variable_path)
# Create .zattrs
zattrs = {
"_ARRAY_DIMENSIONS": dims,
# "coordinates": "time declination right_ascension"
}

shape = []
chunks = []
Expand All @@ -298,117 +282,68 @@ def create_data_variable_meta_data(
else:
chunks.append(xds_dims[d])

# print(chunks,shape)
# assuming data_variable_path has been set compatibly
zattrs_file = os.path.join(data_variable_path, ".zattrs")

if isinstance(fs, s3fs.core.S3FileSystem):
with fs.open(zattrs_file, "w") as file:
json.dump(
zattrs,
file,
indent=4,
sort_keys=True,
ensure_ascii=True,
separators=(",", ": "),
cls=NumberEncoder,
)
else:
# default to assuming we can use primitives
write_json_file(zattrs, zattrs_file)

# Create .zarray
from zarr import n5
fill_value = "NaN" if "f" in dtype else None

compressor_config = n5.compressor_config_to_zarr(
n5.compressor_config_to_n5(compressor.get_config())
# Determine codec pipeline for zarr array metadata.
# compressor must be a zarr v3 BytesBytesCodec (e.g. ZstdCodec).
if isinstance(compressor, zarr.abc.codec.BytesBytesCodec):
codecs = [zarr.codecs.BytesCodec(), compressor]
else:
codecs = None # fall back to zarr defaults for unrecognised types

open_kwargs: dict = dict(
mode="w",
shape=shape,
chunks=chunks,
dtype=dtype,
fill_value=fill_value,
zarr_format=ZARR_FORMAT,
)
if codecs is not None:
open_kwargs["codecs"] = codecs

if "f" in dtype:
zarray = {
"chunks": chunks,
"compressor": compressor_config,
"dtype": dtype,
"fill_value": "NaN",
"filters": None,
"order": "C",
"shape": shape,
"zarr_format": 2,
}

else:
zarray = {
"chunks": chunks,
"compressor": compressor_config,
"dtype": dtype,
"fill_value": None,
"filters": None,
"order": "C",
"shape": shape,
"zarr_format": 2,
}
z_arr = zarr.open_array(data_variable_path, **open_kwargs)
z_arr.attrs["_ARRAY_DIMENSIONS"] = dims

zarr_meta[data_variable_key]["chunks"] = chunks
zarr_meta[data_variable_key]["shape"] = shape

# again, assuming data_variable_path has been set compatibly
zarray_file = os.path.join(data_variable_path, ".zarray")

if isinstance(fs, s3fs.core.S3FileSystem):
with fs.open(zarray_file, "w") as file:
json.dump(
zarray,
file,
indent=4,
sort_keys=True,
ensure_ascii=True,
separators=(",", ": "),
cls=NumberEncoder,
)
else:
# default to assuming we can use primitives
write_json_file(zarray, zarray_file)

return zarr_meta


def write_chunk(img_xds, meta, parallel_dims_chunk_id, image_file):
    """Write one chunk of *img_xds* into the pre-created zarr array.

    Compression is handled by zarr itself (configured in
    ``create_data_variable_meta_data``); this function only computes the
    correct destination slice and delegates the write to the zarr API.

    Parameters
    ----------
    img_xds :
        Dataset (mapping-like) holding the chunk's data variable; if the
        variable named in *meta* is absent, the call is a no-op.
    meta :
        Per-variable metadata dict with keys ``dims``, ``dtype``, ``name``,
        ``chunks`` and ``shape`` (full-array chunk sizes and shape).
    parallel_dims_chunk_id :
        Chunk index per parallelised dimension name; dimensions absent from
        this mapping are written starting at offset 0.
    image_file :
        Root path of the zarr image store; the variable's array lives in a
        subdirectory named after the data variable.
    """
    dims = meta["dims"]
    dtype = meta["dtype"]  # NOTE(review): retained from meta; unused by the zarr write path
    data_variable_name = meta["name"]
    chunks = meta["chunks"]
    shape = meta["shape"]

    # Nothing to write for this variable in this partial dataset.
    if data_variable_name not in img_xds:
        return

    array = img_xds[data_variable_name].values

    # Compute the destination slice in the full zarr array: parallelised
    # dimensions are offset by chunk-index × chunk-size, others start at 0.
    # Clamp the end to the array shape so a ragged final chunk fits.
    slices = []
    for i, d in enumerate(dims):
        if d in parallel_dims_chunk_id:
            idx = parallel_dims_chunk_id[d]
            start = idx * chunks[i]
            end = start + array.shape[i]
        else:
            start, end = 0, array.shape[i]
        slices.append(slice(start, min(end, shape[i])))

    from xradio._utils.zarr.config import ZARR_FORMAT

    # Open the existing array read-write and let zarr handle chunk layout
    # and compression for the region assignment.
    z_arr = zarr.open_array(
        os.path.join(image_file, data_variable_name),
        mode="r+",
        zarr_format=ZARR_FORMAT,
    )
    z_arr[tuple(slices)] = array
10 changes: 8 additions & 2 deletions src/xradio/image/_util/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
import numpy as np
import os
import xarray as xr
import zarr.abc.codec
from ..._utils.zarr.common import _open_dataset


def _xds_to_zarr(
    xds: xr.Dataset,
    zarr_store: str,
    compressor: zarr.abc.codec.BytesBytesCodec | None = None,
    shards: dict | int | None = None,
):
    """Write *xds* to the zarr store at *zarr_store*.

    Pure pass-through to ``_write_zarr``; *compressor* (a zarr v3
    ``BytesBytesCodec``) and *shards* select optional compression and
    sharding, and are forwarded unchanged (``None`` leaves the defaults
    to ``_write_zarr``).
    """
    _write_zarr(xds, zarr_store, compressor=compressor, shards=shards)


def _xds_from_zarr(
Expand Down
Loading
Loading