diff --git a/pyproject.toml b/pyproject.toml index fc00292..d3f08c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires-python = ">=3.10" dependencies = [ "xarray>=2025.0", "cacheout>=0.16.0", - "arcae>=0.3.2", + "arcae>=0.4.0a3", "typing-extensions>=4.12.2", ] diff --git a/tests/conftest.py b/tests/conftest.py index 74c029a..9124b54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import gc + import numpy as np import pytest from arcae.lib.arrow_tables import Table, ms_descriptor @@ -38,14 +40,18 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(autouse=True) def clear_caches(): yield - Multiton._INSTANCE_CACHE.clear() + + # Structure Factories have references to Multitons MSv2StructureFactory._STRUCTURE_CACHE.clear() + Multiton._INSTANCE_CACHE.clear() + gc.collect() @pytest.fixture(scope="session", params=[DEFAULT_SIM_PARAMS]) def simmed_ms(request, tmp_path_factory): - ms = tmp_path_factory.mktemp("simulated") / request.param.pop("name", "test.ms") - simulator = MSStructureSimulator(**{**DEFAULT_SIM_PARAMS, **request.param}) + params = request.param.copy() + ms = tmp_path_factory.mktemp("simulated") / params.pop("name", "test.ms") + simulator = MSStructureSimulator(**{**DEFAULT_SIM_PARAMS, **params}) simulator.simulate_ms(str(ms)) return str(ms) diff --git a/tests/test_write.py b/tests/test_write.py new file mode 100644 index 0000000..51be7c1 --- /dev/null +++ b/tests/test_write.py @@ -0,0 +1,158 @@ +from contextlib import ExitStack + +import arcae +import numpy as np +import pytest +import xarray +from xarray import Dataset, DataTree + +from xarray_ms.backend.msv2.writes import dataset_to_msv2, datatree_to_msv2 +from xarray_ms.errors import MismatchedWriteRegion +from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES + + +@pytest.mark.parametrize("simmed_ms", [{"name": "test_store.ms"}], indirect=True) +def test_store(monkeypatch, simmed_ms): + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + + read = written = False + + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + # Overwrite UVW coordinates with zeroes + # Add a CORRECTED column + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + assert not np.all(node.UVW == 0) + node.UVW[:] = 0 + assert len(node.encoding) > 0 + ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 2 + 3j)) + xdt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + xdt.to_msv2(["UVW", "CORRECTED"]) + written = written or True + + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + assert np.all(node.UVW == 0) + # Non-standard columns aren't yet exposed + # assert np.all(node.CORRECTED == 1 + 2j) + read = read or True + + assert read + assert written + + # But we can check that CORRECTED has been written correctly + with arcae.table(simmed_ms) as T: + np.testing.assert_array_equal(T.getcol("CORRECTED"), 2 + 3j) + + +@pytest.mark.parametrize("simmed_ms", [{"name": "test_store_region.ms"}], indirect=True) +def test_store_region(monkeypatch, simmed_ms): + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + + region = {"time": slice(0, 2), "frequency": slice(2, 4)} + + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + # Add a 
CORRECTED column + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + ds = node.ds.assign(CORRECTED=xarray.zeros_like(node.VISIBILITY)) + xdt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + # Create the new MS columns + xdt.to_msv2(["CORRECTED"], compute=False) + + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + sizes = node.sizes + ds = node.ds.isel(**region) + ds = ds.assign(CORRECTED=xarray.full_like(ds.CORRECTED, 1 + 2j)) + # Now write it out + ds.to_msv2(["CORRECTED"], compute=False, region=region) + + # We can check that CORRECTED has been written correctly + with arcae.table(simmed_ms) as T: + corrected = T.getcol("CORRECTED") + nt, nbl, nf, npol = ( + sizes[d] for d in ("time", "baseline_id", "frequency", "polarization") + ) + corrected = corrected.reshape((nt, nbl, nf, npol)) + ts, fs = (region[d] for d in ("time", "frequency")) + mask = np.full(corrected.shape, False, np.bool_) + mask[ts, :, fs, :] = True + np.testing.assert_array_equal(corrected[mask], 1 + 2j) + np.testing.assert_array_equal(corrected[~mask], 0 + 0j) + + +@pytest.mark.parametrize("chunks", [{"time": 2, "frequency": 2}]) +@pytest.mark.parametrize("simmed_ms", [{"name": "distributed-write.ms"}], indirect=True) +@pytest.mark.parametrize("nworkers", [4]) +@pytest.mark.parametrize("processes", [True, False]) +def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks): + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + da = pytest.importorskip("dask.array") + distributed = pytest.importorskip("dask.distributed") + Client = distributed.Client + LocalCluster = distributed.LocalCluster + + with ExitStack() as stack: + cluster = stack.enter_context(LocalCluster(processes=processes, n_workers=nworkers)) + stack.enter_context(Client(cluster)) + dt = stack.enter_context( + xarray.open_datatree(simmed_ms, chunks=chunks, auto_corrs=True) + ) + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + vis = node.VISIBILITY + sizes = node.sizes + corrected = da.arange(np.prod(vis.shape), dtype=np.int32) + corrected = corrected.reshape(vis.shape).rechunk(vis.data.chunks) + ds = node.ds.assign(CORRECTED=(vis.dims, corrected)) + dt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + # Create the new MS columns + dt.to_msv2(["CORRECTED"], compute=False) + + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + print("Writing") + node.ds.to_msv2(["CORRECTED"], compute=True) + + with arcae.table(simmed_ms) as T: + corrected = T.getcol("CORRECTED") + shape = tuple( + sizes[d] for d in ("time", "baseline_id", "frequency", "polarization") + ) + expected = np.arange(np.prod(vis.shape), dtype=np.int32) + expected = expected.reshape((-1,) + shape[2:]) + np.testing.assert_array_equal(corrected, expected) + + +@pytest.mark.parametrize("simmed_ms", [{"name": "indexed-write.ms"}], indirect=True) +def test_indexed_write(monkeypatch, simmed_ms): + """Check that we throw if we select a variable out with an integer index + and then try to write that sub-selection out""" + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + dt = xarray.open_datatree(simmed_ms) + assert len(dt.children) == 1 + + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + ds = 
node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j)) + dt[node.path] = DataTree(ds) + + dt.to_msv2(["CORRECTED"], compute=False) + + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + ds = node.ds.isel(time=slice(0, 2), baseline_id=slice(0, 2), frequency=1) + with pytest.raises(MismatchedWriteRegion): + ds.to_msv2(["CORRECTED"], compute=True) diff --git a/xarray_ms/backend/msv2/array.py b/xarray_ms/backend/msv2/array.py index bde1fe0..8a77cff 100644 --- a/xarray_ms/backend/msv2/array.py +++ b/xarray_ms/backend/msv2/array.py @@ -1,17 +1,22 @@ from __future__ import annotations +import re +from contextlib import contextmanager from functools import reduce from operator import mul -from typing import TYPE_CHECKING, Any, Callable, Tuple +from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple import numpy as np from xarray.backends import BackendArray from xarray.core.indexing import ( IndexingSupport, OuterIndexer, + expanded_indexer, explicit_indexing_adapter, ) +from xarray_ms.errors import MismatchedWriteRegion + if TYPE_CHECKING: import numpy.typing as npt @@ -20,19 +25,52 @@ TransformerT = Callable[[npt.NDArray], npt.NDArray] +MAIN_PREFIX_DIMS = ("time", "baseline_id") + +DATA_COLUMN_DIM_MISMATCH_RE = re.compile( + r"Number of data dimensions (?P<data>\d+) does not " + r"match number of column dimensions (?P<column>\d+)" +) + -def slice_length(s: npt.NDArray | slice, max_len) -> int: +def slice_length( + s: npt.NDArray | slice, max_len: int, op: Literal["read", "write"] +) -> int: if isinstance(s, np.ndarray): if s.ndim != 1: raise NotImplementedError("Slicing with non-1D numpy arrays") return len(s) - start, stop, step = s.indices(min(max_len, s.stop) if s.stop is not None else max_len) + clamp_op = min if op == "read" else max + start, stop, step = s.indices( + clamp_op(max_len, s.stop) if s.stop is not None else max_len + ) if step != 1: raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported") return stop - start +@contextmanager +def rethrow_arcae_exceptions(): + import pyarrow as pa + + try: + yield + except pa.lib.ArrowInvalid as e: + if m := re.match(DATA_COLUMN_DIM_MISMATCH_RE, str(e)): + data_dims = m.group("data") + column_dims = m.group("column") + raise MismatchedWriteRegion( + f"Attempted to write an array of dimensionality {data_dims} " + f"to a column of dimensionality {column_dims}. " + f"This can occur if xarray variable dimensions " + f"have been selected out with integer indices. 
" + f"Prefer xarray semantics that retain dimensionality " + f"such as slicing during selection or sum(..., keepdims=True)" + ) + raise + + class MSv2Array(BackendArray): """Base MSv2Array backend array class, containing required shape and data type""" @@ -49,6 +87,9 @@ def __init__(self, shape: Tuple[int, ...], dtype: npt.DTypeLike): def __getitem__(self, key) -> npt.NDArray: raise NotImplementedError + def __setitem__(self, key, value) -> None: + raise NotImplementedError + @property def transform(self) -> TransformerT | None: raise NotImplementedError @@ -99,7 +140,7 @@ def __init__( self._default = default self._transform = transform - assert len(shape) >= 2, "(time, baseline_ids) required" + assert len(shape) >= len(MAIN_PREFIX_DIMS), f"{MAIN_PREFIX_DIMS} required" def __getitem__(self, key) -> npt.NDArray: return explicit_indexing_adapter( @@ -118,7 +159,7 @@ def promote_integer_dims(key): def _getitem(self, key) -> npt.NDArray: assert len(key) == len(self.shape) key, squeeze_axis = self.promote_integer_dims(key) - expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape)) + expected_shape = tuple(slice_length(k, s, "read") for k, s in zip(key, self.shape)) if reduce(mul, expected_shape, 1) == 0: return np.empty(expected_shape, dtype=self.dtype) # Map the (time, baseline_id) coordinates onto row indices @@ -126,12 +167,24 @@ def _getitem(self, key) -> npt.NDArray: row_key = (rows.ravel(),) + key[2:] row_shape = (rows.size,) + expected_shape[2:] result = np.full(row_shape, self._default, dtype=self.dtype) - self._table_factory.instance.getcol(self._column, row_key, result) + with rethrow_arcae_exceptions(): + self._table_factory.instance.getcol(self._column, row_key, result) result = result.reshape(rows.shape + expected_shape[2:]) # arcae doesn't handle squeezing out the selecting axis so we do it here. 
result = result.squeeze(axis=squeeze_axis) return self._transform(result) if self._transform else result + def __setitem__(self, key, value: npt.NDArray) -> None: + key = expanded_indexer(key, len(self.shape)) + expected_shape = tuple(slice_length(k, s, "write") for k, s in zip(key, self.shape)) + rows = self._structure_factory.instance[self._partition].row_map[key[:2]] + row_key = (rows.ravel(),) + key[2:] + row_shape = (rows.size,) + expected_shape[2:] + value = self._transform(value) if self._transform else value + value = value.reshape(row_shape) + with rethrow_arcae_exceptions(): + self._table_factory.instance.putcol(self._column, value, index=row_key) + @property def transform(self) -> TransformerT | None: return self._transform @@ -190,7 +243,7 @@ def __getitem__(self, key) -> npt.NDArray: ) def _getitem(self, key) -> npt.NDArray: - expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape)) + expected_shape = tuple(slice_length(k, s, "read") for k, s in zip(key, self.shape)) if reduce(mul, expected_shape, 1) == 0: return np.empty(expected_shape, dtype=self.dtype) low_res_key = tuple(k for i, k in zip(self._low_res_index, key) if i is not None) diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index ea3ff7c..0b795b7 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timezone from importlib.metadata import version as importlib_version -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping import xarray from xarray.backends import BackendEntrypoint @@ -14,6 +14,7 @@ from xarray.core.datatree import DataTree from xarray.core.utils import try_read_magic_number_from_file_or_path +from xarray_ms.backend.msv2.array import MainMSv2Array from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs from xarray_ms.backend.msv2.factories import ( AntennaFactory, @@ -38,6 +39,9 @@ from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT +WriteRegionT = Mapping[str, slice | Literal["auto"]] | Literal["auto"] + + def promote_chunks( structure: MSv2Structure, chunks: Dict | str | None ) -> Dict[PartitionKeyT, Dict[str, int]] | str | None: @@ -84,6 +88,7 @@ class MSv2Store(AbstractWritableDataStore): "_auto_corrs", "_ninstances", "_epoch", + "_write_region", ) _table_factory: Multiton @@ -95,6 +100,7 @@ class MSv2Store(AbstractWritableDataStore): _autocorrs: bool _ninstances: int _epoch: str + _write_region: WriteRegionT def __init__( self, @@ -107,6 +113,7 @@ def __init__( auto_corrs: bool, ninstances: int, epoch: str, + write_region: WriteRegionT, ): self._table_factory = table_factory self._subtable_factories = subtable_factories @@ -117,6 +124,7 @@ def __init__( self._auto_corrs = auto_corrs self._ninstances = ninstances self._epoch = epoch + self._write_region = write_region @classmethod def open( @@ -130,6 +138,7 @@ def open( ninstances: int = 1, epoch: str | None = None, structure_factory: MSv2StructureFactory | None = None, + write_region: WriteRegionT = "auto", ): if not isinstance(ms, str): raise ValueError("Measurement Sets paths must be strings") @@ -170,6 +179,7 @@ def open( auto_corrs=store_args.auto_corrs, ninstances=store_args.ninstances, epoch=store_args.epoch, + write_region=write_region, ) def close(self, **kwargs): @@ -188,9 +198,11 @@ def main_dataset_factory(self) -> CorrelatedFactory: ) def 
get_variables(self): + """Overrides AbstractDataStore.get_variables""" return self.main_dataset_factory().get_variables() def get_attrs(self): + """Overrides AbstractDataStore.get_attrs""" factory = self.main_dataset_factory() attrs = { @@ -206,11 +218,102 @@ def get_attrs(self): return dict(sorted({**attrs, **factory.get_attrs()}.items())) def get_dimensions(self): - return None + """Typically, this hook retrieves Dataset dimensions from the + underlying store, but in the MSv2Store implementation, these + are currently derived from the Dataset. - def get_encoding(self): + Overrides AbstractDataStore.get_dimensions. + """ return {} + def get_encoding(self): + """Encodes the arguments used to create the MSv2Store, + as well as the partition key associated with this partition + of the Measurement Set. + + Overrides AbstractDataStore.get_encoding""" + table_args = self._table_factory.arguments() + + return { + "common_store_args": { + "ms": table_args["filename"], + "ninstances": self._ninstances, + "auto_corrs": self._auto_corrs, + "epoch": self._epoch, + "partition_schema": self._partition_schema, + }, + "partition_key": self._partition_key, + } + + def set_dimensions(self, variables, unlimited_dims=None): + """Set dimensions on the store, based on the variables. + This might be a good point to add columns to the Measurement Set, + but a store only references a partition of the MS so a complete + view of all the variables that would contribute to defining an + MS column is not available. + + This logic must be performed at the higher DataTree level, + so we no-op here. + + Overrides AbstractWritableDataStore.set_dimensions. + """ + pass + + def set_variables( + self, + variables: dict[str, xarray.Variable], + check_encoding_set, + writer, + unlimited_dims=None, + ): + """Set variables for writing to the store + + Overrides AbstractWritableDataStore.set_variables. + """ + for n, v in variables.items(): + target, source = self.prepare_variable( + n, v, n in check_encoding_set, unlimited_dims=unlimited_dims + ) + write_region = {} if self._write_region == "auto" else self._write_region + region = tuple(write_region.get(d, slice(None)) for d in v.dims) + writer.add(source, target, region=region) + + def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): + """Prepare variable for writing to the Measurement Set + + Overrides AbstractWritableDataStore.prepare_variable. 
""" + target = MainMSv2Array( + self._table_factory, + self._structure_factory, + self._partition_key, + name, + variable.shape, + variable.dtype, + ) + return target, variable.data + + def set_write_region(self, ds: Dataset): + """Sets the region for writing on this store, given a source dataset""" + + region: Dict[str, slice | Literal["auto"]] = dict.fromkeys(ds.dims, "auto") + + if self._write_region == "auto": + pass + elif isinstance(self._write_region, dict): + region.update(self._write_region) + else: + raise TypeError(f"'region' ({type(self._write_region)}) should be a dict or \"auto\"") + + try: + expanded_region = { + d: slice(0, ds.sizes[d]) if v == "auto" else v for d, v in region.items() + } + except KeyError as e: + raise ValueError(f"{e} is not a dataset dimension") from e + + self._write_region = expanded_region + class MSv2EntryPoint(BackendEntrypoint): open_dataset_parameters = [ diff --git a/xarray_ms/backend/msv2/entrypoint_utils.py b/xarray_ms/backend/msv2/entrypoint_utils.py index c08cf20..846fc06 100644 --- a/xarray_ms/backend/msv2/entrypoint_utils.py +++ b/xarray_ms/backend/msv2/entrypoint_utils.py @@ -36,9 +36,7 @@ def subtable_factory( name: str, on_missing: Literal["raise", "empty_table"] = "empty_table" ) -> pa.Table: try: - return Table.from_filename( - name, ninstances=1, readonly=True, lockoptions="nolock" - ).to_arrow() + return Table.from_filename(name, ninstances=1, readonly=True).to_arrow() except pa.lib.ArrowInvalid as e: e_str = str(e) if "subtable" in e_str and "is invalid" in e_str: @@ -97,11 +95,7 @@ def __init__( self.partition_schema = partition_schema or DEFAULT_PARTITION_COLUMNS self.preferred_chunks = preferred_chunks or {} self.ms_factory = ms_factory or Multiton( - Table.from_filename, - self.ms, - ninstances=self.ninstances, - readonly=True, - lockoptions="nolock", + Table.from_filename, self.ms, ninstances=self.ninstances, readonly=True ) self.subtable_factories = subtable_factories or { subtable: Multiton(subtable_factory, f"{ms}::{subtable}") diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py index f1bdb5d..ad78e9f 100644 --- a/xarray_ms/backend/msv2/structure.py +++ b/xarray_ms/backend/msv2/structure.py @@ -601,7 +601,7 @@ def read_subtable( else: return None - with Table.from_filename(subtable_path, lockoptions="nolock") as T: + with Table.from_filename(subtable_path) as T: return T.to_arrow(columns=list(columns)) def __init__( diff --git a/xarray_ms/backend/msv2/writes.py b/xarray_ms/backend/msv2/writes.py new file mode 100644 index 0000000..d28d0be --- /dev/null +++ b/xarray_ms/backend/msv2/writes.py @@ -0,0 +1,298 @@ +import warnings +from collections import defaultdict +from typing import Any, Dict, Iterable, Literal, Mapping, Set, Tuple + +import numpy as np +import numpy.typing as npt +from arcae.lib.arrow_tables import ms_descriptor +from xarray import Dataset, DataTree +from xarray.backends.api import _finalize_store, dump_to_store +from xarray.backends.common import ArrayWriter + +from xarray_ms.backend.msv2.array import MAIN_PREFIX_DIMS +from xarray_ms.backend.msv2.entrypoint import MSv2Store +from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs +from xarray_ms.casa_types import NUMPY_TO_CASA_MAP +from xarray_ms.errors import MissingEncodingError +from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES + +ShapeSetType = Set[Tuple[int, ...]] +DTypeSetType = Set[npt.DTypeLike] +ShapeAndDTypeType = Dict[Tuple[str, str], Tuple[ShapeSetType, DTypeSetType]] + + +def validate_column_desc( + 
var_name: str, column: str, shapes: ShapeSetType, column_desc: Dict +) -> None: + # Validate the variable ndim against the column ndim + if (ndim := column_desc.get("ndim")) is not None: + # Multi-dimensional CASA column configuration + multidim = ndim == -1 + ndims = {len(s) for s in shapes} + if not multidim and len(ndims) > 1: + raise ValueError( + f"{column} descriptor specifies a fixed ndim ({ndim}) " + f"but {var_name} has multiple trailing shapes with dimensions {shapes}" + ) + elif not all(len(shape) == 0 for shape in shapes): + # A missing ndim implies the column only has a row dimension + raise ValueError( + f"{column} descriptor specifies a row only column " + f"but {var_name} has trailing shape(s) {shapes}" + ) + + # Validate the variable shape if the column is fixed + if (fixed_shape := column_desc.get("shape")) is not None: + fixed_shape = tuple(fixed_shape) + if len(shapes) > 1: + raise ValueError( + f"Variable {var_name} has multiple trailing shapes {shapes} " + f"but {column} specifies a fixed shape {fixed_shape}" + ) + + if (var_shape := next(iter(shapes))) != fixed_shape: + raise ValueError( + f"Variable {var_name} has trailing shape {var_shape} " + f"but {column} has a fixed shape {fixed_shape}" + ) + + +def fit_tile_shape(shape: Tuple[int, ...], dtype: npt.DTypeLike) -> Dict[str, list]: + """ + Args: + shape: FORTRAN ordered column cell shape + dtype: tile data type + + Returns: + A :code:`{DEFAULTTILESHAPE: tile_shape}` dictionary + """ + nbytes = np.dtype(dtype).itemsize + min_tile_dims = [512] + max_tile_dims = [np.inf] + + for dim in shape: + min_tile = min(dim, 4) # Don't tile <=4 elements. + # For dims which are not exact powers of two, treat them as though + # they are floored to the nearest power of two. + max_tile = int(min(2 ** int(np.log2(dim)) / 8, 64)) + max_tile = min_tile if max_tile < min_tile else max_tile + min_tile_dims.append(min_tile) + max_tile_dims.append(max_tile) + + tile_shape = min_tile_dims.copy() + growth_axis = 0 + + while np.prod(tile_shape) * nbytes < 1024**2: # 1MB tiles. + if tile_shape[growth_axis] < max_tile_dims[growth_axis]: + tile_shape[growth_axis] *= 2 + growth_axis = (growth_axis + 1) % len(tile_shape) + + # The tile shape is C ordered + return {"DEFAULTTILESHAPE": list(tile_shape[::-1])} + + +def generate_column_descriptor( + table_desc: Dict[str, Any], + shapes_and_dtypes: ShapeAndDTypeType, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Synthesises column descriptors from: + + 1. An existing table descriptor. + 2. The canonical (complete) MAIN table descriptor. + 3. The shapes and data types associated with + xarray Variables on a DataTree. + + Args: + table_desc: Table descriptor containing existing + column definitions. + shapes_and_dtypes: A :code:`{(var_name, column): (set(shapes), set(dtypes))}` + mapping of shapes and data types associated with each xarray + Variable. 
+ + Returns: + A (descriptor, dminfo) tuple containing any columns + and data managers that should be created + """ + canonical_table_desc = ms_descriptor("MAIN", complete=True) + actual_desc = {} + dm_groups = [] + + for (var_name, column), (shapes, dtypes) in shapes_and_dtypes.items(): + # If there are existing descriptors, either for + # columns present on the table, or in the canonical definition + # validate that the variable shape matches the column + if column_desc := table_desc.get(column): + validate_column_desc(var_name, column, shapes, column_desc) + elif column_desc := canonical_table_desc.get(column): + validate_column_desc(var_name, column, shapes, column_desc) + else: + # Construct a column descriptor and possibly an + # associated data manager + # Unify variable numpy types + dtype = np.result_type(*dtypes) + + if dtype == object: + raise NotImplementedError( + f"Types of variable {var_name} ({list(dtypes)}) " + f"resolve to an object. " + f"Writing of objects is not supported" + ) + + try: + casa_type = NUMPY_TO_CASA_MAP[np.dtype(dtype).type] + except KeyError as e: + raise ValueError( + f"No CASA type matched NumPy dtype {dtype}\n{NUMPY_TO_CASA_MAP}" + ) from e + + column_desc = {"valueType": casa_type, "option": 0} + + if len(shapes) == 1: + # If the shape is fixed, tile the column + # column descriptor shapes are fortran ordered + fixed_shape = tuple(reversed(next(iter(shapes)))) + row_only = len(fixed_shape) == 0 + if not row_only: + column_desc["option"] |= 4 + column_desc["shape"] = list(fixed_shape) + column_desc["ndim"] = len(fixed_shape) + column_desc["dataManagerGroup"] = dm_group = f"{column}_GROUP" + column_desc["dataManagerType"] = dm_type = "TiledColumnStMan" + + dm_groups.append( + { + "COLUMNS": [column], + "NAME": dm_group, + "TYPE": dm_type, + "SPEC": fit_tile_shape(fixed_shape, dtype), + } + ) + else: + # Variably shaped, use a StandardStMan for now + # but consider a TiledCellStMan in future + column_desc["option"] = 0 + column_desc["dataManagerGroup"] = "StandardStMan" + column_desc["dataManagerType"] = "StandardStMan" + + actual_desc[column] = column_desc + + dminfo = {f"*{i + 1}": g for i, g in enumerate(dm_groups)} + return actual_desc, dminfo + + +def msv2_store_from_dataset(ds: Dataset, region="auto") -> MSv2Store: + try: + common_store_args = ds.encoding["common_store_args"] + partition_key = ds.encoding["partition_key"] + except KeyError as e: + raise MissingEncodingError( + f"Expected encoding key {e} is not present on " + f"a dataset of type {ds.attrs.get('type')}. 
" + f"Writing back to a Measurement Set " + f"is not possible without this information" + ) from e + + # Recover common arguments used to create the original store + # This will re-use existing table and structure factories + store_args = CommonStoreArgs(**common_store_args) + return MSv2Store.open( + ms=store_args.ms, + partition_schema=store_args.partition_schema, + partition_key=partition_key, + preferred_chunks=store_args.preferred_chunks, + auto_corrs=store_args.auto_corrs, + ninstances=store_args.ninstances, + epoch=store_args.epoch, + structure_factory=store_args.structure_factory, + write_region=region, + ) + + +def datatree_to_msv2( + dt: DataTree, + variables: str | Iterable[str], + compute: bool = True, + write_inherited_coords: bool = False, +): + assert isinstance(dt, DataTree) + list_var_names = [variables] if isinstance(variables, str) else list(variables) + set_var_names = set(list_var_names) + + if len(set_var_names) == 0: + warnings.warn("Empty 'variables'") + return + + vis_datasets = [ + n for n in dt.subtree if n.attrs.get("type") in CORRELATED_DATASET_TYPES + ] + + if len(vis_datasets) == 0: + warnings.warn("No visibility datasets were found on the DataTree") + return + + shapes_and_dtypes: ShapeAndDTypeType = defaultdict(lambda: (set(), set())) + + for node in vis_datasets: + assert isinstance(node, DataTree) + if not set_var_names.issubset(node.data_vars.keys()): + raise ValueError( + f"{set_var_names} are not present in all visibility DataTree nodes" + ) + + for n in set_var_names: + var = node.data_vars[n] + + if var.dims[: len(MAIN_PREFIX_DIMS)] != MAIN_PREFIX_DIMS: + raise ValueError( + f"{n} dimensions {var.dims} do not start with {MAIN_PREFIX_DIMS}" + ) + + shapes, dtypes = shapes_and_dtypes[(n, n)] + shapes.add(var.shape[len(MAIN_PREFIX_DIMS) :]) + dtypes.add(var.dtype) + + table_factory = msv2_store_from_dataset(next(iter(vis_datasets)).ds)._table_factory + table_desc = table_factory.instance.tabledesc() + column_descs, dminfo = generate_column_descriptor(table_desc, shapes_and_dtypes) + table_factory.instance.addcols(column_descs, dminfo) + assert set(column_descs.keys()).issubset(table_factory.instance.columns()) + + if compute: + for node in vis_datasets: + at_root = node is dt.root + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + dataset_to_msv2(ds, list_var_names, compute) + + +def dataset_to_msv2( + ds: Dataset, + variables: str | Iterable[str], + compute: bool = True, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] = "auto", +): + assert isinstance(ds, Dataset) + list_vars = [variables] if isinstance(variables, str) else list(variables) + + if len(list_vars) == 0: + return + + # Strip out + # 1. variables that will not be written + # 2. coordinates + # 3. 
attributes + ignored_vars = set(ds.data_vars) - set(list_vars) + ds = ds.drop_vars(ignored_vars | set(ds.coords)).drop_attrs() + + msv2_store = msv2_store_from_dataset(ds, region) + msv2_store.set_write_region(ds) + writer = ArrayWriter() + dump_to_store(ds, msv2_store, writer) + writes = writer.sync(compute=compute) + + if compute: + _finalize_store(writes, msv2_store) + else: + import dask + + return dask.delayed(_finalize_store)(writes, msv2_store) diff --git a/xarray_ms/casa_types.py b/xarray_ms/casa_types.py index 8b762fc..9564f5b 100644 --- a/xarray_ms/casa_types.py +++ b/xarray_ms/casa_types.py @@ -36,6 +36,21 @@ "STRING": object, } +NUMPY_TO_CASA_MAP = { + np.bool_: "BOOL", + np.int8: "CHAR", + np.uint8: "UCHAR", + np.int16: "SHORT", + np.uint16: "USHORT", + np.int32: "INT", + np.uint32: "UINT", + np.float32: "FLOAT", + np.float64: "DOUBLE", + np.complex64: "COMPLEX", + np.complex128: "DCOMPLEX", + object: "STRING", +} + @dataclass class ColumnDesc: diff --git a/xarray_ms/errors.py b/xarray_ms/errors.py index 3e3ead6..4067b1e 100644 --- a/xarray_ms/errors.py +++ b/xarray_ms/errors.py @@ -43,6 +43,10 @@ class InvalidPartitionKey(ValueError): """Raised when a string representing a partition key is invalid""" +class MissingEncodingError(ValueError): + """Raised when encoding information is missing""" + + class PartitioningError(ValueError): """Raised when a logical error is encountered during Measurement Set partitioning""" @@ -57,3 +61,8 @@ class MissingQuantumUnits(ValueError): class MultipleQuantumUnits(ValueError): """Raised when there are multiple QuantumUnit value types in the column""" + + +class MismatchedWriteRegion(ValueError): + """Raised when attempting to write to a chunk of data whose dimensionality + does not match the target column""" diff --git a/xarray_ms/multiton.py b/xarray_ms/multiton.py index f48d3f7..aff4445 100644 --- a/xarray_ms/multiton.py +++ b/xarray_ms/multiton.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Mapping, Tuple +from typing import Any, Callable, ClassVar, Dict, Mapping, Tuple from cacheout import Cache -from xarray_ms.utils import FrozenKey, normalise_args +from xarray_ms.utils import FrozenKey, function_arguments, normalise_args FactoryFunctionT = Callable[..., Any] @@ -45,6 +45,11 @@ def from_reduce_args( ) -> Multiton: return Multiton(factory, *args, **kw) + def arguments(self) -> Dict[str, Any]: + """Return a dictionary of argument values that would be applied + to the factory function signature""" + return function_arguments(self._factory, self._args, self._kw) + def __reduce__( self, ) -> Tuple[Callable, Tuple[Callable, Tuple[Any, ...], Mapping[str, Any]]]: diff --git a/xarray_ms/testing/simulator.py b/xarray_ms/testing/simulator.py index 8f1f17d..fef7daf 100644 --- a/xarray_ms/testing/simulator.py +++ b/xarray_ms/testing/simulator.py @@ -230,7 +230,7 @@ def simulate_ms(self, output_ms: str) -> None: for c in table_desc.pop("__remove_columns__", []): table_desc.pop(c, None) - with Table.ms_from_descriptor(output_ms, "MAIN", table_desc) as T: + with Table.ms_from_descriptor(output_ms, "MAIN", 1, table_desc) as T: startrow = 0 for chunk_desc in self.generate_descriptors(): @@ -350,9 +350,7 @@ def simulate_ms(self, output_ms: str) -> None: T.putcol("RELEASE_DATE", np.asarray([0.0] * self.nobs)) source_table_desc = ms_descriptor("SOURCE", complete=True) - with Table.ms_from_descriptor( - output_ms, "SOURCE", table_desc=source_table_desc - ) as T: + with Table.ms_from_descriptor(output_ms, 
"SOURCE", 1, source_table_desc) as T: T.addrows(self.nfield) T.putcol("TIME", np.asarray([1.0] * self.nfield)) T.putcol("INTERVAL", np.asarray([1.0] * self.nfield)) diff --git a/xarray_ms/utils.py b/xarray_ms/utils.py index 741afa9..ad84c1e 100644 --- a/xarray_ms/utils.py +++ b/xarray_ms/utils.py @@ -1,6 +1,6 @@ import inspect from collections.abc import Callable, Hashable, Mapping, Sequence, Set -from typing import Any, Tuple +from typing import Any, Dict, Tuple from numpy import ndarray @@ -51,7 +51,7 @@ def __str__(self) -> str: def normalise_args( - factory: Callable, args, kw + factory: Callable, args: Tuple, kw: Dict ) -> Tuple[Tuple[Any, ...], Mapping[str, Any]]: """Normalise args and kwargs into a hashable representation @@ -64,18 +64,27 @@ def normalise_args( tuple containing the normalised positional arguments and keyword arguments """ spec = inspect.getfullargspec(factory) - args = list(args) + list_args = list(args) for i, arg in enumerate(spec.args): - if i < len(args): + if i < len(list_args): continue elif arg in kw: - args.append(kw.pop(arg)) + list_args.append(kw.pop(arg)) elif spec.defaults and len(spec.args) - len(spec.defaults) <= i: default = spec.defaults[i - (len(spec.args) - len(spec.defaults))] - args.append(default) + list_args.append(default) - return tuple(args), kw + return tuple(list_args), kw + + +def function_arguments(fn: Callable, args: Tuple, kw: Mapping) -> Dict[str, Any]: + """Given a callable and some arguments and keywords, return a dictionary + of argument values that would be applied to the function signature""" + signature = inspect.signature(fn) + bound_arguments = signature.bind(*args, **kw) + bound_arguments.apply_defaults() + return bound_arguments.arguments def format_docstring(**subs):