From 0f0a2dfbe8559f14314e250205d0d30cca03db7a Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 21 May 2025 16:25:54 +0200 Subject: [PATCH 01/23] Limited write support: initial support --- tests/test_read.py | 1 + tests/test_write.py | 97 ++++++++++++++++++++++ xarray_ms/backend/msv2/array.py | 19 ++++- xarray_ms/backend/msv2/entrypoint.py | 32 ++++++- xarray_ms/backend/msv2/entrypoint_utils.py | 12 ++- xarray_ms/errors.py | 4 + xarray_ms/multiton.py | 9 +- xarray_ms/utils.py | 23 +++-- 8 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 tests/test_write.py diff --git a/tests/test_read.py b/tests/test_read.py index a3db1b8..beed107 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -241,6 +241,7 @@ def test_jittered_intervals(simmed_ms): for p in ["000", "001"]: node = xdt[f"backend_partition_{p}"] + node.VISIBILITY[:5] = 5 + 6j assert np.isclose(node.time.integration_time, DUMP_RATE) diff --git a/tests/test_write.py b/tests/test_write.py new file mode 100644 index 0000000..b0b0a38 --- /dev/null +++ b/tests/test_write.py @@ -0,0 +1,97 @@ +from typing import Iterable + +import numpy as np +import xarray +from xarray import Dataset, DataTree +from xarray.backends.api import dump_to_store + +from xarray_ms.backend.msv2.entrypoint import MSv2Store +from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs +from xarray_ms.errors import MissingEncodingError +from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES + + +def test_store(monkeypatch, simmed_ms, tmp_path): + def datatree_to_msv2( + dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False + ): + assert isinstance(dt, DataTree) + + if isinstance(variables, str): + list_vars = [variables] + else: + list_vars = list(variables) + + if len(list_vars) == 0: + return + + for node in dt.subtree: + # Only write visibility data + if node.attrs.get("type", None) not in CORRELATED_DATASET_TYPES: + continue + + at_root = node is dt + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + ds.to_msv2(list_vars) + + def dataset_to_msv2(ds: Dataset, variables: Iterable[str]): + assert isinstance(ds, Dataset) + + if isinstance(variables, str): + list_vars = [variables] + else: + list_vars = list(variables) + + if len(list_vars) == 0: + return + + try: + common_store_args = ds.encoding["common_store_args"] + partition_key = ds.encoding["partition_key"] + except KeyError as e: + raise MissingEncodingError( + f"Expected encoding key {e} is not present on " + f"the dataset of type {ds.attrs.get('type', None)}. 
" + f"at path {ds.path} " + f"Writing back to a Measurement Set " + f"is not possible without this information" + ) from e + + # Recover common arguments used to create the original store + # This will likely re-use existing table and structure factories + store_args = CommonStoreArgs(**common_store_args) + msv2_store = MSv2Store.open( + ms=store_args.ms, + partition_schema=store_args.partition_schema, + partition_key=partition_key, + preferred_chunks=store_args.preferred_chunks, + auto_corrs=store_args.auto_corrs, + ninstances=1, + epoch=store_args.epoch, + structure_factory=store_args.structure_factory, + ) + + # Strip out coordinates and attributes + ignored_vars = set(ds.data_vars) - set(list_vars) + ds = ds.drop_vars(ds.coords).drop_vars(ignored_vars).drop_attrs() + try: + dump_to_store(ds, msv2_store) + finally: + msv2_store.close() + + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + + with xarray.open_datatree(simmed_ms) as xdt: + # Overwrite UVW coordinates with zeroes + for ds in xdt.subtree: + if "UVW" in ds.data_vars: + assert not np.all(ds.UVW == 0) + ds.UVW[:] = 0 + + xdt.to_msv2("UVW") + + with xarray.open_datatree(simmed_ms) as xdt: + for ds in xdt.subtree: + if "UVW" in ds.data_vars: + assert np.all(ds.UVW == 0) diff --git a/xarray_ms/backend/msv2/array.py b/xarray_ms/backend/msv2/array.py index 68349b7..d87489f 100644 --- a/xarray_ms/backend/msv2/array.py +++ b/xarray_ms/backend/msv2/array.py @@ -4,7 +4,11 @@ import numpy as np from xarray.backends import BackendArray -from xarray.core.indexing import IndexingSupport, explicit_indexing_adapter +from xarray.core.indexing import ( + IndexingSupport, + expanded_indexer, + explicit_indexing_adapter, +) if TYPE_CHECKING: import numpy.typing as npt @@ -43,6 +47,9 @@ def __init__(self, shape: Tuple[int, ...], dtype: npt.DTypeLike): def __getitem__(self, key) -> npt.NDArray: raise NotImplementedError + def __setitem__(self, key, value) -> None: + raise NotImplementedError + @property def transform(self) -> TransformerT | None: raise NotImplementedError @@ -112,6 +119,16 @@ def _getitem(self, key) -> npt.NDArray: result = result.reshape(rows.shape + expected_shape[2:]) return self._transform(result) if self._transform else result + def __setitem__(self, key, value: npt.NDArray) -> None: + key = expanded_indexer(key, len(self.shape)) + expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape)) + rows = self._structure_factory.instance[self._partition].row_map[key[:2]] + row_key = (rows.ravel(),) + key[2:] + row_shape = (rows.size,) + expected_shape[2:] + value = self._transform(value) if self._transform else value + value = value.reshape(row_shape) + self._table_factory.instance.putcol(self._column, value, index=row_key) + @property def transform(self) -> TransformerT | None: return self._transform diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index ea7e73f..5da9c86 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -14,6 +14,7 @@ from xarray.core.datatree import DataTree from xarray.core.utils import try_read_magic_number_from_file_or_path +from xarray_ms.backend.msv2.array import MainMSv2Array from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs from xarray_ms.backend.msv2.factories import ( AntennaDatasetFactory, @@ -205,10 +206,36 @@ def get_attrs(self): return dict(sorted({**attrs, 
**factory.get_attrs()}.items())) def get_dimensions(self): - return None + return {} def get_encoding(self): - return {} + table_args = self._table_factory.arguments() + + return { + "common_store_args": { + "ms": table_args["filename"], + "ninstances": self._ninstances, + "auto_corrs": self._auto_corrs, + "epoch": self._epoch, + "partition_schema": self._partition_schema, + }, + "partition_key": self._partition_key, + } + + def set_dimensions(self, variables, unlimited_dims=None): + if unlimited_dims is not None: + raise NotImplementedError("MSv2 backend doesn't handle unlimited dimensions") + + def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): + target = MainMSv2Array( + self._table_factory, + self._structure_factory, + self._partition_key, + name, + variable.shape, + variable.dtype, + ) + return target, variable.data class MSv2EntryPoint(BackendEntrypoint): @@ -381,6 +408,7 @@ def open_datatree( > 0 ): dt.set_close(vis_ds[0]._close) + dt.encoding["common_store_args"] = vis_ds[0].encoding["common_store_args"] return dt diff --git a/xarray_ms/backend/msv2/entrypoint_utils.py b/xarray_ms/backend/msv2/entrypoint_utils.py index c8bedd0..59f5ff3 100644 --- a/xarray_ms/backend/msv2/entrypoint_utils.py +++ b/xarray_ms/backend/msv2/entrypoint_utils.py @@ -1,5 +1,5 @@ import os.path -from typing import Dict, List +from typing import Any, Dict, List from uuid import uuid4 import pyarrow as pa @@ -96,3 +96,13 @@ def __init__( self.epoch, auto_corrs=self.auto_corrs, ) + + def encode_base_args(self) -> Dict[str, Any]: + return { + "ms": self.ms, + "ninstances": self.ninstances, + "auto_corrs": self.auto_corrs, + "epoch": self.epoch, + "partition_schema": self.partition_schema, + "preferred_chunks": self.preferred_chunks, + } diff --git a/xarray_ms/errors.py b/xarray_ms/errors.py index 99689ae..038b28b 100644 --- a/xarray_ms/errors.py +++ b/xarray_ms/errors.py @@ -20,6 +20,10 @@ class InvalidPartitionKey(ValueError): """Raised when a string representing a partition key is invalid""" +class MissingEncodingError(ValueError): + """Raised when encoding information is missing""" + + class PartitioningError(ValueError): """Raised when a logical error is encountered during Measurement Set partitioning""" diff --git a/xarray_ms/multiton.py b/xarray_ms/multiton.py index f48d3f7..aff4445 100644 --- a/xarray_ms/multiton.py +++ b/xarray_ms/multiton.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Mapping, Tuple +from typing import Any, Callable, ClassVar, Dict, Mapping, Tuple from cacheout import Cache -from xarray_ms.utils import FrozenKey, normalise_args +from xarray_ms.utils import FrozenKey, function_arguments, normalise_args FactoryFunctionT = Callable[..., Any] @@ -45,6 +45,11 @@ def from_reduce_args( ) -> Multiton: return Multiton(factory, *args, **kw) + def arguments(self) -> Dict[str, Any]: + """Return a dictionary of argument values that would be applied + to the factory function signature""" + return function_arguments(self._factory, self._args, self._kw) + def __reduce__( self, ) -> Tuple[Callable, Tuple[Callable, Tuple[Any, ...], Mapping[str, Any]]]: diff --git a/xarray_ms/utils.py b/xarray_ms/utils.py index 741afa9..ad84c1e 100644 --- a/xarray_ms/utils.py +++ b/xarray_ms/utils.py @@ -1,6 +1,6 @@ import inspect from collections.abc import Callable, Hashable, Mapping, Sequence, Set -from typing import Any, Tuple +from typing import Any, Dict, Tuple from numpy import ndarray @@ -51,7 +51,7 @@ def __str__(self) -> 
str: def normalise_args( - factory: Callable, args, kw + factory: Callable, args: Tuple, kw: Dict ) -> Tuple[Tuple[Any, ...], Mapping[str, Any]]: """Normalise args and kwargs into a hashable representation @@ -64,18 +64,27 @@ def normalise_args( tuple containing the normalised positional arguments and keyword arguments """ spec = inspect.getfullargspec(factory) - args = list(args) + list_args = list(args) for i, arg in enumerate(spec.args): - if i < len(args): + if i < len(list_args): continue elif arg in kw: - args.append(kw.pop(arg)) + list_args.append(kw.pop(arg)) elif spec.defaults and len(spec.args) - len(spec.defaults) <= i: default = spec.defaults[i - (len(spec.args) - len(spec.defaults))] - args.append(default) + list_args.append(default) - return tuple(args), kw + return tuple(list_args), kw + + +def function_arguments(fn: Callable, args: Tuple, kw: Mapping) -> Dict[str, Any]: + """Given a callable and some arguments and keywords, return a dictionary + of argument values that would be applied to the function signature""" + signature = inspect.signature(fn) + bound_arguments = signature.bind(*args, **kw) + bound_arguments.apply_defaults() + return bound_arguments.arguments def format_docstring(**subs): From d582291660065f2aa5136f03d1006cf3cf5261e2 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 21 May 2025 16:27:14 +0200 Subject: [PATCH 02/23] Remove debugging assign --- tests/test_read.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_read.py b/tests/test_read.py index beed107..a3db1b8 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -241,7 +241,6 @@ def test_jittered_intervals(simmed_ms): for p in ["000", "001"]: node = xdt[f"backend_partition_{p}"] - node.VISIBILITY[:5] = 5 + 6j assert np.isclose(node.time.integration_time, DUMP_RATE) From 99d19bf960bc82f3227ab2a0abb1655d02d98cf0 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 21 May 2025 16:28:32 +0200 Subject: [PATCH 03/23] Remove unused method --- xarray_ms/backend/msv2/entrypoint_utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/xarray_ms/backend/msv2/entrypoint_utils.py b/xarray_ms/backend/msv2/entrypoint_utils.py index 59f5ff3..c8bedd0 100644 --- a/xarray_ms/backend/msv2/entrypoint_utils.py +++ b/xarray_ms/backend/msv2/entrypoint_utils.py @@ -1,5 +1,5 @@ import os.path -from typing import Any, Dict, List +from typing import Dict, List from uuid import uuid4 import pyarrow as pa @@ -96,13 +96,3 @@ def __init__( self.epoch, auto_corrs=self.auto_corrs, ) - - def encode_base_args(self) -> Dict[str, Any]: - return { - "ms": self.ms, - "ninstances": self.ninstances, - "auto_corrs": self.auto_corrs, - "epoch": self.epoch, - "partition_schema": self.partition_schema, - "preferred_chunks": self.preferred_chunks, - } From 7253be7f4c7edec97f56487e8cd65856fc8d8c1f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 21 May 2025 16:30:42 +0200 Subject: [PATCH 04/23] Remove unused assign of common_store_args on the DataTree root --- xarray_ms/backend/msv2/entrypoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index 5da9c86..150755f 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -408,7 +408,6 @@ def open_datatree( > 0 ): dt.set_close(vis_ds[0]._close) - dt.encoding["common_store_args"] = vis_ds[0].encoding["common_store_args"] return dt From cdd957356b17b792bcc99d7595bab539156f1961 Mon Sep 17 00:00:00 2001 From: Simon Perkins 
Date: Wed, 21 May 2025 16:35:27 +0200 Subject: [PATCH 05/23] Remove unused tmp_path fixture --- tests/test_write.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_write.py b/tests/test_write.py index b0b0a38..3e2c479 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -11,7 +11,7 @@ from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES -def test_store(monkeypatch, simmed_ms, tmp_path): +def test_store(monkeypatch, simmed_ms): def datatree_to_msv2( dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False ): From 7d3e4f995208384620f6bb95429924dcf6265776 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 21 May 2025 17:01:44 +0200 Subject: [PATCH 06/23] Use ternary operator --- tests/test_write.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index 3e2c479..aeda86c 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -16,11 +16,7 @@ def datatree_to_msv2( dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False ): assert isinstance(dt, DataTree) - - if isinstance(variables, str): - list_vars = [variables] - else: - list_vars = list(variables) + list_vars = [variables] if isinstance(variables, str) else list(variables) if len(list_vars) == 0: return @@ -36,11 +32,7 @@ def datatree_to_msv2( def dataset_to_msv2(ds: Dataset, variables: Iterable[str]): assert isinstance(ds, Dataset) - - if isinstance(variables, str): - list_vars = [variables] - else: - list_vars = list(variables) + list_vars = [variables] if isinstance(variables, str) else list(variables) if len(list_vars) == 0: return From 658dfb67682e1e0c8d0ec3fa144d77eb3ce0bb9f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 16 Jul 2025 14:47:52 +0200 Subject: [PATCH 07/23] Prototype support for adding new columns --- tests/test_write.py | 109 ++++--------- xarray_ms/backend/msv2/writes.py | 252 +++++++++++++++++++++++++++++++ xarray_ms/casa_types.py | 15 ++ 3 files changed, 299 insertions(+), 77 deletions(-) create mode 100644 xarray_ms/backend/msv2/writes.py diff --git a/tests/test_write.py b/tests/test_write.py index aeda86c..7904e4e 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,89 +1,44 @@ -from typing import Iterable - +import arcae import numpy as np import xarray from xarray import Dataset, DataTree -from xarray.backends.api import dump_to_store -from xarray_ms.backend.msv2.entrypoint import MSv2Store -from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs -from xarray_ms.errors import MissingEncodingError +from xarray_ms.backend.msv2.writes import dataset_to_msv2, datatree_to_msv2 from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES def test_store(monkeypatch, simmed_ms): - def datatree_to_msv2( - dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False - ): - assert isinstance(dt, DataTree) - list_vars = [variables] if isinstance(variables, str) else list(variables) - - if len(list_vars) == 0: - return - - for node in dt.subtree: - # Only write visibility data - if node.attrs.get("type", None) not in CORRELATED_DATASET_TYPES: - continue - - at_root = node is dt - ds = node.to_dataset(inherit=write_inherited_coords or at_root) - ds.to_msv2(list_vars) - - def dataset_to_msv2(ds: Dataset, variables: Iterable[str]): - assert isinstance(ds, Dataset) - list_vars = [variables] if isinstance(variables, str) else list(variables) - - if len(list_vars) == 0: - return - - try: - common_store_args 
= ds.encoding["common_store_args"] - partition_key = ds.encoding["partition_key"] - except KeyError as e: - raise MissingEncodingError( - f"Expected encoding key {e} is not present on " - f"the dataset of type {ds.attrs.get('type', None)}. " - f"at path {ds.path} " - f"Writing back to a Measurement Set " - f"is not possible without this information" - ) from e - - # Recover common arguments used to create the original store - # This will likely re-use existing table and structure factories - store_args = CommonStoreArgs(**common_store_args) - msv2_store = MSv2Store.open( - ms=store_args.ms, - partition_schema=store_args.partition_schema, - partition_key=partition_key, - preferred_chunks=store_args.preferred_chunks, - auto_corrs=store_args.auto_corrs, - ninstances=1, - epoch=store_args.epoch, - structure_factory=store_args.structure_factory, - ) - - # Strip out coordinates and attributes - ignored_vars = set(ds.data_vars) - set(list_vars) - ds = ds.drop_vars(ds.coords).drop_vars(ignored_vars).drop_attrs() - try: - dump_to_store(ds, msv2_store) - finally: - msv2_store.close() - monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) - with xarray.open_datatree(simmed_ms) as xdt: - # Overwrite UVW coordinates with zeroes - for ds in xdt.subtree: - if "UVW" in ds.data_vars: - assert not np.all(ds.UVW == 0) - ds.UVW[:] = 0 - - xdt.to_msv2("UVW") + read = written = False - with xarray.open_datatree(simmed_ms) as xdt: - for ds in xdt.subtree: - if "UVW" in ds.data_vars: - assert np.all(ds.UVW == 0) + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + # Overwrite UVW coordinates with zeroes + # Add a CORRECTED column + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + assert not np.all(node.UVW == 0) + node.UVW[:] = 0 + assert len(node.encoding) > 0 + ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j)) + xdt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + xdt.to_msv2(["UVW", "CORRECTED"]) + written = written or True + + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + assert np.all(node.UVW == 0) + # Non-standard columns aren't yet exposed + # assert np.all(node.CORRECTED == 1 + 2j) + read = read or True + + assert read + assert written + + # But we can check that CORRECTED has been written correctly + with arcae.table(simmed_ms) as T: + np.testing.assert_array_equal(T.getcol("CORRECTED"), 1 + 2j) diff --git a/xarray_ms/backend/msv2/writes.py b/xarray_ms/backend/msv2/writes.py new file mode 100644 index 0000000..c7f8b65 --- /dev/null +++ b/xarray_ms/backend/msv2/writes.py @@ -0,0 +1,252 @@ +import warnings +from collections import defaultdict +from typing import Any, Dict, Iterable, Set, Tuple + +import numpy as np +import numpy.typing as npt +from arcae.lib.arrow_tables import ms_descriptor +from xarray import Dataset, DataTree +from xarray.backends.api import dump_to_store + +from xarray_ms.backend.msv2.entrypoint import MSv2Store +from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs +from xarray_ms.casa_types import NUMPY_TO_CASA_MAP +from xarray_ms.errors import MissingEncodingError +from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES + +ShapeSetType = Set[Tuple[int, ...]] +DTypeSetType = Set[npt.DTypeLike] +ShapeAndDTypeType = Dict[Tuple[str, str], Tuple[ShapeSetType, DTypeSetType]] + + +def 
validate_column_desc(
+  var_name: str, column: str, shapes: ShapeSetType, column_desc: Dict
+) -> None:
+  # Validate the variable ndim against the column ndim
+  if (ndim := column_desc.get("ndim")) is not None:
+    # Multi-dimensional CASA column configuration
+    multidim = ndim == -1
+    ndims = {len(s) for s in shapes}
+    if not multidim and len(ndims) > 1:
+      raise ValueError(
+        f"{column} descriptor specifies a fixed ndim ({ndim}) "
+        f"but {var_name} has multiple trailing shapes with dimensions {shapes}"
+      )
+  elif not all(len(shape) == 0 for shape in shapes):
+    # The absence of ndim implies the column only has a row dimension
+    raise ValueError(
+      f"{column} descriptor specifies a row only column "
+      f"but {var_name} has trailing shape(s) {shapes}"
+    )
+
+  # Validate the variable shape if the column is fixed
+  if (fixed_shape := column_desc.get("shape")) is not None:
+    fixed_shape = tuple(fixed_shape)
+    if len(shapes) > 1:
+      raise ValueError(
+        f"Variable {var_name} has multiple trailing shapes {shapes} "
+        f"but {column} specifies a fixed shape {fixed_shape}"
+      )
+
+    if (var_shape := next(iter(shapes))) != fixed_shape:
+      raise ValueError(
+        f"Variable {var_name} has trailing shape {var_shape} "
+        f"but {column} has a fixed shape {fixed_shape}"
+      )
+
+
+def fit_tile_shape(shape: Tuple[int, ...], dtype: npt.DTypeLike) -> Dict[str, np.int32]:
+  """
+  Args:
+    shape: tile shape
+    dtype: tile data type
+
+  Returns:
+    A :code:`{DEFAULTTILESHAPE: tile_shape}` dictionary
+  """
+  nbytes = np.dtype(dtype).itemsize
+  min_tile_dims = [512]
+  max_tile_dims = [np.inf]
+
+  for dim in shape:
+    min_tile = min(dim, 4)  # Don't tile <=4 elements.
+    # For dims which are not exact powers of two, treat them as though
+    # they are floored to the nearest power of two.
+    max_tile = int(min(2 ** int(np.log2(dim)) / 8, 64))
+    max_tile = min_tile if max_tile < min_tile else max_tile
+    min_tile_dims.append(min_tile)
+    max_tile_dims.append(max_tile)
+
+  tile_shape = min_tile_dims.copy()
+  growth_axis = 0
+
+  while np.prod(tile_shape) * nbytes < 1024**2:  # 1MB tiles.
+    if tile_shape[growth_axis] < max_tile_dims[growth_axis]:
+      tile_shape[growth_axis] *= 2
+    growth_axis = (growth_axis + 1) % len(tile_shape)
+
+  return {"DEFAULTTILESHAPE": list(tile_shape[::-1])}
+
+
+def generate_descriptor(
+  shapes_and_dtypes: ShapeAndDTypeType,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+  """
+  Returns:
+    A (descriptor, dminfo) tuple containing any columns
+    and data managers that should be created
+  """
+  canonical_table_desc = ms_descriptor("MAIN", complete=True)
+  actual_desc = {}
+  dm_groups = []
+
+  for (var_name, column), (shapes, dtypes) in shapes_and_dtypes.items():
+    if column_desc := canonical_table_desc.get(column):
+      # If an existing descriptor exists for this column,
+      # validate that the variable shape matches the column
+      validate_column_desc(var_name, column, shapes, column_desc)
+    else:
+      # Construct a column descriptor and possibly an associated data manager
+      # Unify variable numpy types
+      dtype = np.result_type(*dtypes)
+
+      if dtype is object:
+        raise NotImplementedError(
+          f"Types of variable {var_name} ({list(dtypes)}) "
+          f"resolve to an object. 
" + f"Writing of objects is not supported" + ) + + try: + casa_type = NUMPY_TO_CASA_MAP[np.dtype(dtype).type] + except KeyError as e: + raise ValueError( + f"No CASA type matched NumPy dtype {dtype}\n{NUMPY_TO_CASA_MAP}" + ) from e + + column_desc = {"valueType": casa_type, "option": 0} + + if len(shapes) == 1: + # Fixed shape, Tile the column + fixed_shape = tuple(reversed(next(iter(shapes)))) + row_only = len(fixed_shape) == 0 + if not row_only: + column_desc["option"] |= 4 + column_desc["shape"] = list(fixed_shape) + column_desc["ndim"] = len(fixed_shape) + column_desc["dataManagerGroup"] = dm_group = f"{column}_GROUP" + column_desc["dataManagerType"] = dm_type = "TiledColumnStMan" + + dm_groups.append( + { + "COLUMNS": [column], + "NAME": dm_group, + "TYPE": dm_type, + "SPEC": fit_tile_shape(fixed_shape, dtype), + } + ) + else: + # Variably shaped + column_desc["option"] = 0 + column_desc["dataManagerGroup"] = "StandardStMan" + column_desc["dataManagerType"] = "StandardStMan" + + actual_desc[column] = column_desc + + dminfo = {f"*{i + 1}": g for i, g in enumerate(dm_groups)} + return actual_desc, dminfo + + +def msv2_store_from_dataset(ds: Dataset) -> MSv2Store: + try: + common_store_args = ds.encoding["common_store_args"] + partition_key = ds.encoding["partition_key"] + except KeyError as e: + raise MissingEncodingError( + f"Expected encoding key {e} is not present on " + f"a dataset of type {ds.attrs.get('type')}. " + f"Writing back to a Measurement Set " + f"is not possible without this information" + ) from e + + # Recover common arguments used to create the original store + # This will likely re-use existing table and structure factories + store_args = CommonStoreArgs(**common_store_args) + return MSv2Store.open( + ms=store_args.ms, + partition_schema=store_args.partition_schema, + partition_key=partition_key, + preferred_chunks=store_args.preferred_chunks, + auto_corrs=store_args.auto_corrs, + ninstances=store_args.ninstances, + epoch=store_args.epoch, + structure_factory=store_args.structure_factory, + ) + + +def datatree_to_msv2( + dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False +): + assert isinstance(dt, DataTree) + list_var_names = [variables] if isinstance(variables, str) else list(variables) + set_var_names = set(list_var_names) + + if len(set_var_names) == 0: + warnings.warn("Empty 'variables'") + return + + vis_datasets = [ + n for n in dt.subtree if n.attrs.get("type") in CORRELATED_DATASET_TYPES + ] + + if len(vis_datasets) == 0: + warnings.warn("No visibility datasets were found on the DataTree") + return + + shapes_and_dtypes: ShapeAndDTypeType = defaultdict(lambda: (set(), set())) + + for node in vis_datasets: + assert isinstance(node, DataTree) + if not set_var_names.issubset(node.data_vars.keys()): + raise ValueError( + f"{set_var_names} are not present in all visibility DataTree nodes" + ) + + for n in set_var_names: + var = node.data_vars[n] + PREFIX_DIMS = ("time", "baseline_id") + + if var.dims[:2] != PREFIX_DIMS: + raise ValueError(f"{n} dimensions {var.dims} do not start with {PREFIX_DIMS}") + + shapes, dtypes = shapes_and_dtypes[(n, n)] + + shapes.add(var.shape[len(PREFIX_DIMS) :]) + dtypes.add(var.dtype) + + column_descs, dminfo = generate_descriptor(shapes_and_dtypes) + store = msv2_store_from_dataset(next(iter(vis_datasets)).ds) + store._table_factory.instance.addcols(column_descs, dminfo) + assert set(column_descs.keys()).issubset(store._table_factory.instance.columns()) + + for node in vis_datasets: + at_root = node 
is dt.root + node = node.to_dataset(inherit=write_inherited_coords or at_root) + node.to_msv2(list_var_names) + + +def dataset_to_msv2(ds: Dataset, variables: str | Iterable[str]): + assert isinstance(ds, Dataset) + list_vars = [variables] if isinstance(variables, str) else list(variables) + + if len(list_vars) == 0: + return + + # Strip out coordinates and attributes + msv2_store = msv2_store_from_dataset(ds) + ignored_vars = set(ds.data_vars) - set(list_vars) + ds = ds.drop_vars(ds.coords).drop_vars(ignored_vars).drop_attrs() + try: + dump_to_store(ds, msv2_store) + finally: + msv2_store.close() diff --git a/xarray_ms/casa_types.py b/xarray_ms/casa_types.py index 8a965d7..f72a42e 100644 --- a/xarray_ms/casa_types.py +++ b/xarray_ms/casa_types.py @@ -36,6 +36,21 @@ "STRING": object, } +NUMPY_TO_CASA_MAP = { + np.bool_: "BOOL", + np.int8: "CHAR", + np.uint8: "UCHAR", + np.int16: "SHORT", + np.uint16: "USHORT", + np.int32: "INT", + np.uint32: "UINT", + np.float32: "FLOAT", + np.float64: "DOUBLE", + np.complex64: "COMPLEX", + np.complex128: "DCOMPLEX", + object: "STRING", +} + @dataclass class ColumnDesc: From ce3716a3abf64687eef67467fcf069c2396a82c4 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 16 Jul 2025 15:24:01 +0200 Subject: [PATCH 08/23] Use existing table descriptor during synthesis of column descriptors --- xarray_ms/backend/msv2/writes.py | 41 ++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/xarray_ms/backend/msv2/writes.py b/xarray_ms/backend/msv2/writes.py index c7f8b65..558feff 100644 --- a/xarray_ms/backend/msv2/writes.py +++ b/xarray_ms/backend/msv2/writes.py @@ -88,10 +88,25 @@ def fit_tile_shape(shape: Tuple[int, ...], dtype: npt.DTypeLike) -> Dict[str, np return {"DEFAULTTILESHAPE": list(tile_shape[::-1])} -def generate_descriptor( +def generate_column_descriptor( + table_desc: Dict[str, Any], shapes_and_dtypes: ShapeAndDTypeType, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ + Synthesises a column descriptor from: + + 1. An existing table descriptor. + 2. A complete table descriptor. + 3. The shapes and data types associated with + xarray Variables on a DataTree. + + Args: + table_desc: Table descriptor containing existing + column definitions. + shapes_and_dtypes: A :code:`{var_name: (set(shapes), set(dtypes))}` + mapping of shapes and data types associated with each xarray + Variable. 
+ Returns: A (descriptor, dminfo) tuple containing any columns and data managers that should be created @@ -101,12 +116,16 @@ def generate_descriptor( dm_groups = [] for (var_name, column), (shapes, dtypes) in shapes_and_dtypes.items(): - if column_desc := canonical_table_desc.get(column): - # If an existing descriptor exists for this column, - # validate that the variable shape matches the column + # If there are existing descriptors, either for + # columns present on the table, or in the canonical definition + # validate that the variable shape matches the column + if column_desc := table_desc.get(column): + validate_column_desc(var_name, column, shapes, column_desc) + elif column_desc := canonical_table_desc.get(column): validate_column_desc(var_name, column, shapes, column_desc) else: - # Construct a column descriptor and possibly an associated data manager + # Construct a column descriptor and possibly an + # associated data manager # Unify variable numpy types dtype = np.result_type(*dtypes) @@ -146,7 +165,8 @@ def generate_descriptor( } ) else: - # Variably shaped + # Variably shaped, use a StandardStMan for now + # but consider a TiledCellStMan in future column_desc["option"] = 0 column_desc["dataManagerGroup"] = "StandardStMan" column_desc["dataManagerType"] = "StandardStMan" @@ -224,10 +244,11 @@ def datatree_to_msv2( shapes.add(var.shape[len(PREFIX_DIMS) :]) dtypes.add(var.dtype) - column_descs, dminfo = generate_descriptor(shapes_and_dtypes) - store = msv2_store_from_dataset(next(iter(vis_datasets)).ds) - store._table_factory.instance.addcols(column_descs, dminfo) - assert set(column_descs.keys()).issubset(store._table_factory.instance.columns()) + table_factory = msv2_store_from_dataset(next(iter(vis_datasets)).ds)._table_factory + table_desc = table_factory.instance.tabledesc() + column_descs, dminfo = generate_column_descriptor(table_desc, shapes_and_dtypes) + table_factory.instance.addcols(column_descs, dminfo) + assert set(column_descs.keys()).issubset(table_factory.instance.columns()) for node in vis_datasets: at_root = node is dt.root From ffdc94d4a6b01025313e01de535ec544f26aa04b Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 17 Jul 2025 16:54:30 +0200 Subject: [PATCH 09/23] Preliminary support for writing with regions --- tests/test_write.py | 39 +++++++++++++++++ xarray_ms/backend/msv2/array.py | 2 +- xarray_ms/backend/msv2/entrypoint.py | 47 ++++++++++++++++++++- xarray_ms/backend/msv2/writes.py | 63 +++++++++++++++++++--------- 4 files changed, 129 insertions(+), 22 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index 7904e4e..4db4e8d 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -42,3 +42,42 @@ def test_store(monkeypatch, simmed_ms): # But we can check that CORRECTED has been written correctly with arcae.table(simmed_ms) as T: np.testing.assert_array_equal(T.getcol("CORRECTED"), 1 + 2j) + + +def test_store_region(monkeypatch, simmed_ms): + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + + region = {"time": slice(0, 2), "frequency": slice(2, 4)} + + with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: + # Overwrite UVW coordinates with zeroes + # Add a CORRECTED column + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j)) + xdt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + xdt.to_msv2(["UVW", 
"CORRECTED"], compute=False) + + for node in xdt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + sizes = node.sizes + # Slice out the region + ds = node.ds.isel(**region) + # Now write it out + ds.to_msv2(["UVW", "CORRECTED"], compute=False, region=region) + + # But we can check that CORRECTED has been written correctly + with arcae.table(simmed_ms) as T: + corrected = T.getcol("CORRECTED") + nt, nbl, nf, npol = ( + sizes[d] for d in ("time", "baseline_id", "frequency", "polarization") + ) + corrected = corrected.reshape((nt, nbl, nf, npol)) + ts, fs = (region[d] for d in ("time", "frequency")) + mask = np.full(corrected.shape, False, np.bool_) + mask[ts, :, fs, :] = True + np.testing.assert_array_equal(corrected[mask], 1 + 2j) + np.testing.assert_array_equal(corrected[~mask], 0 + 0j) diff --git a/xarray_ms/backend/msv2/array.py b/xarray_ms/backend/msv2/array.py index 59e8772..8fd7a93 100644 --- a/xarray_ms/backend/msv2/array.py +++ b/xarray_ms/backend/msv2/array.py @@ -28,7 +28,7 @@ def slice_length(s: npt.NDArray | slice, max_len) -> int: raise NotImplementedError("Slicing with non-1D numpy arrays") return len(s) - start, stop, step = s.indices(max_len) + start, stop, step = s.indices(max(max_len, s.stop) if s.stop is not None else max_len) if step != 1: raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported") return stop - start diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index 150755f..ce8f3cf 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timezone from importlib.metadata import version as importlib_version -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping import xarray from xarray.backends import BackendEntrypoint @@ -38,6 +38,9 @@ from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT +WriteRegionType = Mapping[str, slice | Literal["auto"]] | Literal["auto"] + + def promote_chunks( structure: MSv2Structure, chunks: Dict | str | None ) -> Dict[PartitionKeyT, Dict[str, int]] | str | None: @@ -84,6 +87,7 @@ class MSv2Store(AbstractWritableDataStore): "_auto_corrs", "_ninstances", "_epoch", + "_write_region", ) _table_factory: Multiton @@ -95,6 +99,7 @@ class MSv2Store(AbstractWritableDataStore): _autocorrs: bool _ninstances: int _epoch: str + _write_region: WriteRegionType def __init__( self, @@ -107,6 +112,7 @@ def __init__( auto_corrs: bool, ninstances: int, epoch: str, + write_region: WriteRegionType, ): self._table_factory = table_factory self._subtable_factories = subtable_factories @@ -117,6 +123,7 @@ def __init__( self._auto_corrs = auto_corrs self._ninstances = ninstances self._epoch = epoch + self._write_region = write_region @classmethod def open( @@ -130,6 +137,7 @@ def open( ninstances: int = 1, epoch: str | None = None, structure_factory: MSv2StructureFactory | None = None, + write_region: WriteRegionType = "auto", ): if not isinstance(ms, str): raise ValueError("Measurement Sets paths must be strings") @@ -170,6 +178,7 @@ def open( auto_corrs=store_args.auto_corrs, ninstances=store_args.ninstances, epoch=store_args.epoch, + write_region=write_region, ) def close(self, **kwargs): @@ -226,6 +235,21 @@ def set_dimensions(self, variables, unlimited_dims=None): if unlimited_dims is not None: raise NotImplementedError("MSv2 backend doesn't handle unlimited 
dimensions") + def set_variables( + self, + variables: dict[str, xarray.Variable], + check_encoding_set, + writer, + unlimited_dims=None, + ): + for n, v in variables.items(): + target, source = self.prepare_variable( + n, v, n in check_encoding_set, unlimited_dims=unlimited_dims + ) + write_region = {} if self._write_region == "auto" else self._write_region + region = tuple(write_region.get(d, slice(None)) for d in v.dims) + writer.add(source, target, region=region) + def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): target = MainMSv2Array( self._table_factory, @@ -237,6 +261,27 @@ def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims= ) return target, variable.data + def set_write_region(self, ds: Dataset): + """Sets the region for writing on this store, given a source dataset""" + + region: Dict[str, slice | Literal["auto"]] = dict.fromkeys(ds.dims, "auto") + + if self._write_region == "auto": + pass + elif isinstance(self._write_region, dict): + region.update(self._write_region) + else: + raise TypeError(f"'region' ({type(region)}) should be a dict or \"auto\"") + + try: + expanded_region = { + d: slice(0, ds.sizes[d]) if v == "auto" else v for d, v in region.items() + } + except KeyError as e: + raise ValueError(f"{e} is not a dataset dimension") + + self._write_region = expanded_region + class MSv2EntryPoint(BackendEntrypoint): open_dataset_parameters = [ diff --git a/xarray_ms/backend/msv2/writes.py b/xarray_ms/backend/msv2/writes.py index 558feff..9e70b10 100644 --- a/xarray_ms/backend/msv2/writes.py +++ b/xarray_ms/backend/msv2/writes.py @@ -1,12 +1,13 @@ import warnings from collections import defaultdict -from typing import Any, Dict, Iterable, Set, Tuple +from typing import Any, Dict, Iterable, Literal, Mapping, Set, Tuple import numpy as np import numpy.typing as npt from arcae.lib.arrow_tables import ms_descriptor from xarray import Dataset, DataTree -from xarray.backends.api import dump_to_store +from xarray.backends.api import _finalize_store, dump_to_store +from xarray.backends.common import ArrayWriter from xarray_ms.backend.msv2.entrypoint import MSv2Store from xarray_ms.backend.msv2.entrypoint_utils import CommonStoreArgs @@ -58,7 +59,7 @@ def validate_column_desc( def fit_tile_shape(shape: Tuple[int, ...], dtype: npt.DTypeLike) -> Dict[str, np.int32]: """ Args: - shape: tile shape + shape: FORTRAN ordered tile shape dtype: tile data type Returns: @@ -85,6 +86,7 @@ def fit_tile_shape(shape: Tuple[int, ...], dtype: npt.DTypeLike) -> Dict[str, np tile_shape[growth_axis] *= 2 growth_axis = (growth_axis + 1) % len(tile_shape) + # The tile shape is C ordered return {"DEFAULTTILESHAPE": list(tile_shape[::-1])} @@ -146,7 +148,8 @@ def generate_column_descriptor( column_desc = {"valueType": casa_type, "option": 0} if len(shapes) == 1: - # Fixed shape, Tile the column + # If the shape is fixed, Tile the column + # column descriptor shapes are fortran ordered fixed_shape = tuple(reversed(next(iter(shapes)))) row_only = len(fixed_shape) == 0 if not row_only: @@ -177,7 +180,7 @@ def generate_column_descriptor( return actual_desc, dminfo -def msv2_store_from_dataset(ds: Dataset) -> MSv2Store: +def msv2_store_from_dataset(ds: Dataset, region="auto") -> MSv2Store: try: common_store_args = ds.encoding["common_store_args"] partition_key = ds.encoding["partition_key"] @@ -190,7 +193,7 @@ def msv2_store_from_dataset(ds: Dataset) -> MSv2Store: ) from e # Recover common arguments used to create the original store - # This will 
likely re-use existing table and structure factories + # This will re-use existing table and structure factories store_args = CommonStoreArgs(**common_store_args) return MSv2Store.open( ms=store_args.ms, @@ -201,11 +204,15 @@ def msv2_store_from_dataset(ds: Dataset) -> MSv2Store: ninstances=store_args.ninstances, epoch=store_args.epoch, structure_factory=store_args.structure_factory, + write_region=region, ) def datatree_to_msv2( - dt: DataTree, variables: str | Iterable[str], write_inherited_coords: bool = False + dt: DataTree, + variables: str | Iterable[str], + compute: Literal[True] = True, + write_inherited_coords: bool = False, ): assert isinstance(dt, DataTree) list_var_names = [variables] if isinstance(variables, str) else list(variables) @@ -240,7 +247,6 @@ def datatree_to_msv2( raise ValueError(f"{n} dimensions {var.dims} do not start with {PREFIX_DIMS}") shapes, dtypes = shapes_and_dtypes[(n, n)] - shapes.add(var.shape[len(PREFIX_DIMS) :]) dtypes.add(var.dtype) @@ -250,24 +256,41 @@ def datatree_to_msv2( table_factory.instance.addcols(column_descs, dminfo) assert set(column_descs.keys()).issubset(table_factory.instance.columns()) - for node in vis_datasets: - at_root = node is dt.root - node = node.to_dataset(inherit=write_inherited_coords or at_root) - node.to_msv2(list_var_names) + if compute: + for node in vis_datasets: + at_root = node is dt.root + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + dataset_to_msv2(ds, list_var_names, compute) -def dataset_to_msv2(ds: Dataset, variables: str | Iterable[str]): +def dataset_to_msv2( + ds: Dataset, + variables: str | Iterable[str], + compute: Literal[True] = True, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] = "auto", +): assert isinstance(ds, Dataset) list_vars = [variables] if isinstance(variables, str) else list(variables) if len(list_vars) == 0: return - # Strip out coordinates and attributes - msv2_store = msv2_store_from_dataset(ds) + # Strip out + # 1. variables that will not be written + # 2. coordinates + # 3. 
attributes
   ignored_vars = set(ds.data_vars) - set(list_vars)
-  ds = ds.drop_vars(ds.coords).drop_vars(ignored_vars).drop_attrs()
-  try:
-    dump_to_store(ds, msv2_store)
-  finally:
-    msv2_store.close()
+  ds = ds.drop_vars(ignored_vars | set(ds.coords)).drop_attrs()
+
+  msv2_store = msv2_store_from_dataset(ds, region)
+  msv2_store.set_write_region(ds)
+  writer = ArrayWriter()
+  dump_to_store(ds, msv2_store, writer)
+  writes = writer.sync(compute=compute)
+
+  if compute:
+    _finalize_store(writes, msv2_store)
+  else:
+    import dask
+
+    return dask.delayed(_finalize_store)(writes, msv2_store)

From d68ece268f030ff584aa177ef72e90c1bec6ded1 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Fri, 18 Jul 2025 10:58:25 +0200
Subject: [PATCH 10/23] Comments and typing

---
 xarray_ms/backend/msv2/entrypoint.py | 43 ++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 6 deletions(-)

diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py
index ce8f3cf..d5cd5a5 100644
--- a/xarray_ms/backend/msv2/entrypoint.py
+++ b/xarray_ms/backend/msv2/entrypoint.py
@@ -38,7 +38,7 @@
 from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT
 
 
-WriteRegionType = Mapping[str, slice | Literal["auto"]] | Literal["auto"]
+WriteRegionT = Mapping[str, slice | Literal["auto"]] | Literal["auto"]
 
 
 def promote_chunks(
@@ -99,7 +99,7 @@ class MSv2Store(AbstractWritableDataStore):
   _autocorrs: bool
   _ninstances: int
   _epoch: str
-  _write_region: WriteRegionType
+  _write_region: WriteRegionT
 
   def __init__(
     self,
@@ -112,7 +112,7 @@ def __init__(
     auto_corrs: bool,
     ninstances: int,
     epoch: str,
-    write_region: WriteRegionType,
+    write_region: WriteRegionT,
   ):
     self._table_factory = table_factory
@@ -137,7 +137,7 @@ def open(
     ninstances: int = 1,
     epoch: str | None = None,
     structure_factory: MSv2StructureFactory | None = None,
-    write_region: WriteRegionType = "auto",
+    write_region: WriteRegionT = "auto",
   ):
     if not isinstance(ms, str):
       raise ValueError("Measurement Sets paths must be strings")
@@ -197,9 +197,11 @@ def main_dataset_factory(self) -> CorrelatedDatasetFactory:
     )
 
   def get_variables(self):
+    """Overrides AbstractDataStore.get_variables"""
     return self.main_dataset_factory().get_variables()
 
   def get_attrs(self):
+    """Overrides AbstractDataStore.get_attrs"""
     factory = self.main_dataset_factory()
 
     attrs = {
@@ -215,9 +217,20 @@ def get_attrs(self):
     return dict(sorted({**attrs, **factory.get_attrs()}.items()))
 
   def get_dimensions(self):
+    """Typically, this hook retrieves Dataset dimensions from the
+    underlying store, but in the MSv2Store implementation, these
+    are currently derived from the Dataset.
+
+    Overrides AbstractDataStore.get_dimensions.
+    """
     return {}
 
   def get_encoding(self):
+    """Encodes the arguments used to create the MSv2Store,
+    as well as the partition key associated with this partition
+    of the Measurement Set.
+
+    Overrides AbstractDataStore.get_encoding"""
     table_args = self._table_factory.arguments()
 
     return {
@@ -232,8 +245,18 @@ def get_encoding(self):
   def set_dimensions(self, variables, unlimited_dims=None):
-    if unlimited_dims is not None:
-      raise NotImplementedError("MSv2 backend doesn't handle unlimited dimensions")
+    """Set dimensions on the store, based on the variables.
+    This might be a good point to add columns to the Measurement Set,
+    but a store only references a partition of the MS so a complete
+    view of all the variables that would contribute to defining an
+    MS column is not available. 
+ + This logic must be performed at the higher DataTree level, + so we noop here. + + Overrides AbstractWritableDataStore.set_dimensions. + """ + pass def set_variables( self, @@ -242,6 +265,10 @@ def set_variables( writer, unlimited_dims=None, ): + """Set variables for writing to the store + + Overrides AbstractWritableDataStore.set_variables. + """ for n, v in variables.items(): target, source = self.prepare_variable( n, v, n in check_encoding_set, unlimited_dims=unlimited_dims @@ -251,6 +278,10 @@ def set_variables( writer.add(source, target, region=region) def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): + """Prepare variable for writing to the Measurement Set + + Overrides AbstractWritableDataStore.prepare_variable. + """ target = MainMSv2Array( self._table_factory, self._structure_factory, From 4d1e0bb257b04243b956cb166f3c7ea3cfec3c60 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 12 Aug 2025 15:59:02 +0200 Subject: [PATCH 11/23] Depend on arcae 0.4.0-alpha.1 --- pyproject.toml | 2 +- xarray_ms/backend/msv2/entrypoint_utils.py | 10 ++-------- xarray_ms/backend/msv2/structure.py | 2 +- xarray_ms/testing/simulator.py | 6 ++---- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b113f7e..83a9635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires-python = ">=3.10" dependencies = [ "xarray>=2025.0", "cacheout>=0.16.0", - "arcae>=0.3.0", + "arcae>=0.4.0a1", "typing-extensions>=4.12.2", ] diff --git a/xarray_ms/backend/msv2/entrypoint_utils.py b/xarray_ms/backend/msv2/entrypoint_utils.py index c8bedd0..56106f3 100644 --- a/xarray_ms/backend/msv2/entrypoint_utils.py +++ b/xarray_ms/backend/msv2/entrypoint_utils.py @@ -25,9 +25,7 @@ def subtable_factory(name: str) -> pa.Table: - return Table.from_filename( - name, ninstances=1, readonly=True, lockoptions="nolock" - ).to_arrow() + return Table.from_filename(name, ninstances=1, readonly=True).to_arrow() class CommonStoreArgs: @@ -79,11 +77,7 @@ def __init__( self.partition_schema = partition_schema or DEFAULT_PARTITION_COLUMNS self.preferred_chunks = preferred_chunks or {} self.ms_factory = ms_factory or Multiton( - Table.from_filename, - self.ms, - ninstances=self.ninstances, - readonly=True, - lockoptions="nolock", + Table.from_filename, self.ms, ninstances=self.ninstances, readonly=True ) self.subtable_factories = subtable_factories or { subtable: Multiton(subtable_factory, f"{ms}::{subtable}") diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py index ccd206e..024c789 100644 --- a/xarray_ms/backend/msv2/structure.py +++ b/xarray_ms/backend/msv2/structure.py @@ -601,7 +601,7 @@ def read_subtable( else: return None - with Table.from_filename(subtable_path, lockoptions="nolock") as T: + with Table.from_filename(subtable_path) as T: return T.to_arrow(columns=list(columns)) def __init__( diff --git a/xarray_ms/testing/simulator.py b/xarray_ms/testing/simulator.py index 4672bdd..f53ac02 100644 --- a/xarray_ms/testing/simulator.py +++ b/xarray_ms/testing/simulator.py @@ -225,7 +225,7 @@ def simulate_ms(self, output_ms: str) -> None: # Generate descriptors, create simulated data from the descriptors # and write simulated data to the main Measurement Set - with Table.ms_from_descriptor(output_ms, "MAIN", self.table_desc) as T: + with Table.ms_from_descriptor(output_ms, "MAIN", 1, self.table_desc) as T: startrow = 0 for chunk_desc in self.generate_descriptors(): @@ -343,9 +343,7 @@ def 
simulate_ms(self, output_ms: str) -> None:
       T.putcol("RELEASE_DATE", np.asarray([0.0] * self.nobs))
 
     source_table_desc = ms_descriptor("SOURCE", complete=True)
-    with Table.ms_from_descriptor(
-      output_ms, "SOURCE", table_desc=source_table_desc
-    ) as T:
+    with Table.ms_from_descriptor(output_ms, "SOURCE", 1, source_table_desc) as T:
       T.addrows(self.nfield)
       T.putcol("TIME", np.asarray([1.0] * self.nfield))
       T.putcol("INTERVAL", np.asarray([1.0] * self.nfield))

From 76672067299f1ed6a14f8c73af00bed16aacc0a9 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Wed, 13 Aug 2025 14:04:24 +0200
Subject: [PATCH 12/23] Distinguish data written by two test cases

---
 tests/test_write.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_write.py b/tests/test_write.py
index 4db4e8d..20eeb6b 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -21,7 +21,7 @@ def test_store(monkeypatch, simmed_ms):
         assert not np.all(node.UVW == 0)
         node.UVW[:] = 0
         assert len(node.encoding) > 0
-        ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j))
+        ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 2 + 3j))
         xdt[node.path] = DataTree(ds)
         assert len(node.encoding) > 0
 
@@ -41,7 +41,7 @@ def test_store(monkeypatch, simmed_ms):
 
   # But we can check that CORRECTED has been written correctly
   with arcae.table(simmed_ms) as T:
-    np.testing.assert_array_equal(T.getcol("CORRECTED"), 1 + 2j)
+    np.testing.assert_array_equal(T.getcol("CORRECTED"), 2 + 3j)

From ce26c370e12ab2277f387a60509488769b8202a0 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Wed, 13 Aug 2025 14:06:08 +0200
Subject: [PATCH 13/23] Temporarily get test cases working by using a function
 scope for the simmed_ms fixture

---
 tests/conftest.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 4d06155..042b363 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,10 +14,11 @@ def clear_caches():
   MSv2StructureFactory._STRUCTURE_CACHE.clear()
 
 
-@pytest.fixture(scope="session", params=[DEFAULT_SIM_PARAMS])
+@pytest.fixture(scope="function", params=[DEFAULT_SIM_PARAMS])
 def simmed_ms(request, tmp_path_factory):
-  ms = tmp_path_factory.mktemp("simulated") / request.param.pop("name", "test.ms")
-  simulator = MSStructureSimulator(**{**DEFAULT_SIM_PARAMS, **request.param})
+  params = request.param.copy()
+  ms = tmp_path_factory.mktemp("simulated") / params.pop("name", "test.ms")
+  simulator = MSStructureSimulator(**{**DEFAULT_SIM_PARAMS, **params})
   simulator.simulate_ms(str(ms))
   return str(ms)

From c567a5c40d10231fce2112565a826cfc44dd2a57 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Mon, 25 Aug 2025 18:18:59 +0200
Subject: [PATCH 14/23] Upgrade to arcae 0.4.0.alpha.2

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 83a9635..0b04e18 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires-python = ">=3.10"
 dependencies = [
   "xarray>=2025.0",
   "cacheout>=0.16.0",
-  "arcae>=0.4.0a1",
+  "arcae>=0.4.0a2",
   "typing-extensions>=4.12.2",
 ]

From f5f95305af9feb7bc39170b2372423d010ef4671 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Mon, 25 Aug 2025 18:23:10 +0200
Subject: [PATCH 15/23] Change simmed_ms test fixture back to session scope

---
 tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 042b363..081e81a 100644
--- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ def clear_caches(): MSv2StructureFactory._STRUCTURE_CACHE.clear() -@pytest.fixture(scope="function", params=[DEFAULT_SIM_PARAMS]) +@pytest.fixture(scope="session", params=[DEFAULT_SIM_PARAMS]) def simmed_ms(request, tmp_path_factory): params = request.param.copy() ms = tmp_path_factory.mktemp("simulated") / params.pop("name", "test.ms") From df6db52b328d00ceef5bd5ad4923b9dbbc92555a Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 25 Aug 2025 18:24:03 +0200 Subject: [PATCH 16/23] Create specific test data for individual write tests --- tests/test_write.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index 20eeb6b..5417b30 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,5 +1,6 @@ import arcae import numpy as np +import pytest import xarray from xarray import Dataset, DataTree @@ -7,6 +8,7 @@ from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES +@pytest.mark.parametrize("simmed_ms", [{"name": "test_store.ms"}], indirect=True) def test_store(monkeypatch, simmed_ms): monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) @@ -44,6 +46,7 @@ def test_store(monkeypatch, simmed_ms): np.testing.assert_array_equal(T.getcol("CORRECTED"), 2 + 3j) +@pytest.mark.parametrize("simmed_ms", [{"name": "test_store_region.ms"}], indirect=True) def test_store_region(monkeypatch, simmed_ms): monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) @@ -51,25 +54,25 @@ def test_store_region(monkeypatch, simmed_ms): region = {"time": slice(0, 2), "frequency": slice(2, 4)} with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt: - # Overwrite UVW coordinates with zeroes # Add a CORRECTED column for node in xdt.subtree: if node.attrs.get("type") in CORRELATED_DATASET_TYPES: - ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j)) + ds = node.ds.assign(CORRECTED=xarray.zeros_like(node.VISIBILITY)) xdt[node.path] = DataTree(ds) assert len(node.encoding) > 0 - xdt.to_msv2(["UVW", "CORRECTED"], compute=False) + # Create the new MS columns + xdt.to_msv2(["CORRECTED"], compute=False) for node in xdt.subtree: if node.attrs.get("type") in CORRELATED_DATASET_TYPES: sizes = node.sizes - # Slice out the region - ds = node.ds.isel(**region) + ds = ds.isel(**region) + ds = ds.assign(CORRECTED=xarray.full_like(ds.CORRECTED, 1 + 2j)) # Now write it out - ds.to_msv2(["UVW", "CORRECTED"], compute=False, region=region) + ds.to_msv2(["CORRECTED"], compute=False, region=region) - # But we can check that CORRECTED has been written correctly + # We can check that CORRECTED has been written correctly with arcae.table(simmed_ms) as T: corrected = T.getcol("CORRECTED") nt, nbl, nf, npol = ( From 82b9402a88d830fead296f5b3cf4a86b3b87438b Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 25 Aug 2025 18:25:01 +0200 Subject: [PATCH 17/23] Add a dask process/threads test case --- tests/test_write.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_write.py b/tests/test_write.py index 5417b30..da557a7 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,3 +1,5 @@ +from contextlib import ExitStack + import arcae import numpy as np import pytest @@ -84,3 +86,49 @@ def test_store_region(monkeypatch, 
simmed_ms): mask[ts, :, fs, :] = True np.testing.assert_array_equal(corrected[mask], 1 + 2j) np.testing.assert_array_equal(corrected[~mask], 0 + 0j) + + +@pytest.mark.parametrize("chunks", [{"time": 2, "frequency": 2}]) +@pytest.mark.parametrize("simmed_ms", [{"name": "distributed-write.ms"}], indirect=True) +@pytest.mark.parametrize("nworkers", [4]) +@pytest.mark.parametrize("processes", [True, False]) +def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks): + monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False) + monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False) + da = pytest.importorskip("dask.array") + distributed = pytest.importorskip("dask.distributed") + Client = distributed.Client + LocalCluster = distributed.LocalCluster + + with ExitStack() as stack: + cluster = stack.enter_context(LocalCluster(processes=processes, n_workers=nworkers)) + stack.enter_context(Client(cluster)) + dt = stack.enter_context( + xarray.open_datatree(simmed_ms, chunks=chunks, auto_corrs=True) + ) + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + vis = node.VISIBILITY + sizes = node.sizes + corrected = da.arange(np.prod(vis.shape), dtype=np.int32) + corrected = corrected.reshape(vis.shape).rechunk(vis.data.chunks) + ds = node.ds.assign(CORRECTED=(vis.dims, corrected)) + dt[node.path] = DataTree(ds) + assert len(node.encoding) > 0 + + # Create the new MS columns + dt.to_msv2(["CORRECTED"], compute=False) + + for node in dt.subtree: + if node.attrs.get("type") in CORRELATED_DATASET_TYPES: + print("Writing") + node.ds.to_msv2(["CORRECTED"], compute=True) + + with arcae.table(simmed_ms) as T: + corrected = T.getcol("CORRECTED") + shape = tuple( + sizes[d] for d in ("time", "baseline_id", "frequency", "polarization") + ) + expected = np.arange(np.prod(vis.shape), dtype=np.int32) + expected = expected.reshape((-1,) + shape[2:]) + np.testing.assert_array_equal(corrected, expected) From 9a4e4fa31b48c313b4c0a7898d473f783c566cc8 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 26 Aug 2025 09:46:24 +0200 Subject: [PATCH 18/23] Use max(shape[d], s.stop) to clamp expected_shape during writes --- xarray_ms/backend/msv2/array.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/xarray_ms/backend/msv2/array.py b/xarray_ms/backend/msv2/array.py index 84a8475..6d62890 100644 --- a/xarray_ms/backend/msv2/array.py +++ b/xarray_ms/backend/msv2/array.py @@ -2,7 +2,7 @@ from functools import reduce from operator import mul -from typing import TYPE_CHECKING, Any, Callable, Tuple +from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple import numpy as np from xarray.backends import BackendArray @@ -22,13 +22,18 @@ TransformerT = Callable[[npt.NDArray], npt.NDArray] -def slice_length(s: npt.NDArray | slice, max_len) -> int: +def slice_length( + s: npt.NDArray | slice, max_len: int, op: Literal["read", "write"] +) -> int: if isinstance(s, np.ndarray): if s.ndim != 1: raise NotImplementedError("Slicing with non-1D numpy arrays") return len(s) - start, stop, step = s.indices(min(max_len, s.stop) if s.stop is not None else max_len) + clamp_op = min if op == "read" else max + start, stop, step = s.indices( + clamp_op(max_len, s.stop) if s.stop is not None else max_len + ) if step != 1: raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported") return stop - start @@ -112,7 +117,7 @@ def __getitem__(self, key) -> npt.NDArray: def _getitem(self, key) -> npt.NDArray: 
From 0af4a0f9754f03c9651aef989c94d102f02aec13 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Wed, 27 Aug 2025 15:34:31 +0200
Subject: [PATCH 19/23] Remove deprecated lockoptions kwarg

---
 xarray_ms/backend/msv2/entrypoint_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/xarray_ms/backend/msv2/entrypoint_utils.py b/xarray_ms/backend/msv2/entrypoint_utils.py
index 514697a..e84095c 100644
--- a/xarray_ms/backend/msv2/entrypoint_utils.py
+++ b/xarray_ms/backend/msv2/entrypoint_utils.py
@@ -36,9 +36,7 @@ def subtable_factory(
   name: str, on_missing: Literal["raise", "empty_table"] = "empty_table"
 ) -> pa.Table:
   try:
-    return Table.from_filename(
-      name, ninstances=1, readonly=True, lockoptions="nolock"
-    ).to_arrow()
+    return Table.from_filename(name, ninstances=1, readonly=True).to_arrow()
   except pa.lib.ArrowInvalid as e:
     if "subtable" in e.msg and "is invalid" in e.msg:
       if on_missing == "raise":

From 0a8354de12cd95f541617e6c7600f8dc781f8522 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Tue, 2 Sep 2025 11:10:16 +0200
Subject: [PATCH 20/23] Disable the threads case in test_distributed_write

---
 tests/test_write.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_write.py b/tests/test_write.py
index 640e72e..2ebaa49 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -92,7 +92,7 @@ def test_store_region(monkeypatch, simmed_ms):
 @pytest.mark.parametrize("chunks", [{"time": 2, "frequency": 2}])
 @pytest.mark.parametrize("simmed_ms", [{"name": "distributed-write.ms"}], indirect=True)
 @pytest.mark.parametrize("nworkers", [4])
-@pytest.mark.parametrize("processes", [True, False])
+@pytest.mark.parametrize("processes", [True])
 def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks):
   monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
   monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)
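Note: the distributed test introduced in PATCH 17 (and narrowed in PATCH 20) follows the same two-phase shape as the region test: create the columns once up front, then let dask workers write the chunks. A hedged usage sketch; LocalCluster and Client are real dask.distributed APIs, while "demo.ms" and the to_msv2 helpers monkeypatched by these tests are placeholders, not public xarray-ms API:

  import xarray
  from dask.distributed import Client, LocalCluster

  from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES

  with LocalCluster(processes=True, n_workers=4) as cluster, Client(cluster):
    with xarray.open_datatree(
      "demo.ms", chunks={"time": 2, "frequency": 2}, auto_corrs=True
    ) as dt:
      # Derive a new dask-backed column on each correlated dataset
      for node in dt.subtree:
        if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
          ds = node.ds.assign(CORRECTED=node.VISIBILITY)
          dt[node.path] = xarray.DataTree(ds)

      # Phase 1: create the target MS columns, writing no data
      dt.to_msv2(["CORRECTED"], compute=False)

      # Phase 2: each node's dask chunks are written by the cluster workers
      for node in dt.subtree:
        if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
          node.ds.to_msv2(["CORRECTED"], compute=True)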
From e10f6937eaf2891d81eab5c31b2cf6b4fd82b0da Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Tue, 2 Sep 2025 11:19:38 +0200
Subject: [PATCH 21/23] Create a new MS for test_indexed_write

Also re-enable the threads case in test_distributed_write.

---
 tests/test_write.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_write.py b/tests/test_write.py
index 2ebaa49..51be7c1 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -92,7 +92,7 @@ def test_store_region(monkeypatch, simmed_ms):
 @pytest.mark.parametrize("chunks", [{"time": 2, "frequency": 2}])
 @pytest.mark.parametrize("simmed_ms", [{"name": "distributed-write.ms"}], indirect=True)
 @pytest.mark.parametrize("nworkers", [4])
-@pytest.mark.parametrize("processes", [True])
+@pytest.mark.parametrize("processes", [True, False])
 def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks):
   monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
   monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)
@@ -135,6 +135,7 @@ def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks):
   np.testing.assert_array_equal(corrected, expected)
 
 
+@pytest.mark.parametrize("simmed_ms", [{"name": "indexed-write.ms"}], indirect=True)
 def test_indexed_write(monkeypatch, simmed_ms):
   """Check that we throw if we select a variable out with an integer index
   and then try write that sub-selection out"""

From 3cce2d24d43d3a44270d5e980f7229306d241774 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Tue, 2 Sep 2025 11:20:13 +0200
Subject: [PATCH 22/23] Introduce a gc.collect into clear_caches

---
 tests/conftest.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 49f6971..9124b54 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,5 @@
+import gc
+
 import numpy as np
 import pytest
 from arcae.lib.arrow_tables import Table, ms_descriptor
@@ -38,8 +40,11 @@ def pytest_collection_modifyitems(config, items):
 @pytest.fixture(autouse=True)
 def clear_caches():
   yield
-  Multiton._INSTANCE_CACHE.clear()
+
+  # Structure Factories have references to Multitons
   MSv2StructureFactory._STRUCTURE_CACHE.clear()
+  Multiton._INSTANCE_CACHE.clear()
+  gc.collect()

From 71836e549b3b11c034acc20a21720eaea8176a4a Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Fri, 19 Sep 2025 10:29:48 +0200
Subject: [PATCH 23/23] Depend on arcae 0.4.0a3

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index fac6c4b..d3f08c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires-python = ">=3.10"
 dependencies = [
   "xarray>=2025.0",
   "cacheout>=0.16.0",
-  "arcae>=0.4.0a2",
+  "arcae>=0.4.0a3",
   "typing-extensions>=4.12.2",
 ]
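Note on PATCH 22: clearing a cache only drops its strong references, so objects involved in reference cycles are reclaimed by CPython's cyclic collector rather than by clear() itself. The patch comment states that Structure Factories hold references to Multitons; whether a true cycle forms there is an inference here. A generic CPython illustration, not xarray-ms code:

  import gc
  import weakref

  class CachedThing:
    def __init__(self):
      # A reference cycle: refcounting alone cannot reclaim this instance
      self.self_ref = self

  cache = {"key": CachedThing()}
  probe = weakref.ref(cache["key"])

  cache.clear()
  assert probe() is not None  # the cycle keeps the instance alive past clear()
  gc.collect()
  assert probe() is None  # only the cyclic collector releases it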