Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0f0a2df
Limited write support: initial support
sjperkins May 21, 2025
d582291
Remove debugging assign
sjperkins May 21, 2025
99d19bf
Remove unused method
sjperkins May 21, 2025
7253be7
Remove unused assign of common_store_args on the DataTree root
sjperkins May 21, 2025
cdd9573
Remove unused tmp_path fixture
sjperkins May 21, 2025
7d3e4f9
Use ternary operator
sjperkins May 21, 2025
389c451
Merge branch 'main' into write-support
sjperkins Jun 30, 2025
4cffec8
Merge branch 'main' into write-support
sjperkins Jul 16, 2025
658dfb6
Prototype support for adding new columns
sjperkins Jul 16, 2025
ce3716a
Use existing table descriptor during synthesis of column descriptors
sjperkins Jul 16, 2025
ffdc94d
Preliminary support for writing with regions
sjperkins Jul 17, 2025
d68ece2
Comments and typing
sjperkins Jul 18, 2025
30c54fa
Merge branch 'main' into write-support
sjperkins Jul 18, 2025
6fa1347
Merge branch 'main' into write-support
sjperkins Jul 18, 2025
85623e9
Merge branch 'main' into write-support
sjperkins Jul 18, 2025
4d1e0bb
Depend on arcae 0.4.0-alpha.1
sjperkins Aug 12, 2025
7667206
Distinguish data written by two test cases
sjperkins Aug 13, 2025
ce26c37
Temporarily get test cases working by using a function session scope …
sjperkins Aug 13, 2025
c567a5c
Upgrade to arcae 0.4.0.alpha.2
sjperkins Aug 25, 2025
f5f9530
Change simmed_ms test fixture back to scope
sjperkins Aug 25, 2025
df6db52
Create specific test data for individual write tests
sjperkins Aug 25, 2025
82b9402
Add a dask process/threads test case
sjperkins Aug 25, 2025
919946a
Merge branch 'main' into write-support
sjperkins Aug 26, 2025
9a4e4fa
Use max(shape[d], s.stop) to clamp expected_shape during writes
sjperkins Aug 26, 2025
84b1afe
Merge branch 'main' into write-support
sjperkins Aug 26, 2025
a92b605
Merge branch 'main' into write-support
sjperkins Aug 27, 2025
0af4a0f
Remove deprecated lockoptions kwarg
sjperkins Aug 27, 2025
451f830
Merge branch 'main' into write-support
sjperkins Sep 2, 2025
0a8354d
Disable threads case
sjperkins Sep 2, 2025
e10f693
Create a new MS for test_indexed_write
sjperkins Sep 2, 2025
3cce2d2
Introduce a gc.collect into clear_caches
sjperkins Sep 2, 2025
d3ab196
Merge branch 'main' into write-support
sjperkins Sep 3, 2025
91cb53b
Merge branch 'main' into write-support
sjperkins Sep 18, 2025
71836e5
Depend on arcae 0.4.0a3
sjperkins Sep 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ requires-python = ">=3.10"
dependencies = [
"xarray>=2025.0",
"cacheout>=0.16.0",
"arcae>=0.3.2",
"arcae>=0.4.0a3",
"typing-extensions>=4.12.2",
]

Expand Down
12 changes: 9 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import gc

import numpy as np
import pytest
from arcae.lib.arrow_tables import Table, ms_descriptor
Expand Down Expand Up @@ -38,14 +40,18 @@ def pytest_collection_modifyitems(config, items):
@pytest.fixture(autouse=True)
def clear_caches():
  """Clear the structure and instance caches after every test.

  The fixture is autouse, so it wraps each test: the test body runs at the
  ``yield`` and the caches are emptied afterwards.
  """
  yield

  # Structure Factories have references to Multitons, so the structure
  # cache is emptied first and the Multiton instance cache second.
  MSv2StructureFactory._STRUCTURE_CACHE.clear()
  Multiton._INSTANCE_CACHE.clear()
  # Run a collection straight away instead of waiting for the next automatic
  # one. NOTE(review): presumably this releases resources held by the
  # evicted objects before the next test starts -- confirm.
  gc.collect()


@pytest.fixture(scope="session", params=[DEFAULT_SIM_PARAMS])
def simmed_ms(request, tmp_path_factory):
  """Simulate an MS in a fresh temporary directory and return its path.

  ``request.param`` is a dict of simulator arguments that is layered over
  ``DEFAULT_SIM_PARAMS``. An optional ``"name"`` entry selects the name of
  the MS inside the temporary directory (default ``"test.ms"``) and is not
  forwarded to the simulator.

  Returns:
    The path of the simulated MS as a ``str``.
  """
  # Work on a copy: request.param is shared with the parametrization (and is
  # DEFAULT_SIM_PARAMS itself for the default case), so popping from it
  # directly would leak state into later uses of the fixture.
  params = request.param.copy()
  ms = tmp_path_factory.mktemp("simulated") / params.pop("name", "test.ms")
  simulator = MSStructureSimulator(**{**DEFAULT_SIM_PARAMS, **params})
  simulator.simulate_ms(str(ms))
  return str(ms)

Expand Down
158 changes: 158 additions & 0 deletions tests/test_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from contextlib import ExitStack

import arcae
import numpy as np
import pytest
import xarray
from xarray import Dataset, DataTree

from xarray_ms.backend.msv2.writes import dataset_to_msv2, datatree_to_msv2
from xarray_ms.errors import MismatchedWriteRegion
from xarray_ms.msv4_types import CORRELATED_DATASET_TYPES


@pytest.mark.parametrize("simmed_ms", [{"name": "test_store.ms"}], indirect=True)
def test_store(monkeypatch, simmed_ms):
  """Overwrite UVW, add a CORRECTED column, write both and read them back."""
  monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
  monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)

  written = False
  read = False

  with xarray.open_datatree(simmed_ms, auto_corrs=True) as tree:
    for node in tree.subtree:
      if node.attrs.get("type") not in CORRELATED_DATASET_TYPES:
        continue

      # Zero the UVW coordinates in place
      assert not np.all(node.UVW == 0)
      node.UVW[:] = 0
      assert len(node.encoding) > 0

      # Attach a constant-valued CORRECTED variable shaped like VISIBILITY
      corrected = xarray.full_like(node.VISIBILITY, 2 + 3j)
      tree[node.path] = DataTree(node.ds.assign(CORRECTED=corrected))
      assert len(node.encoding) > 0

    tree.to_msv2(["UVW", "CORRECTED"])
    written = True

  with xarray.open_datatree(simmed_ms, auto_corrs=True) as tree:
    for node in tree.subtree:
      if node.attrs.get("type") not in CORRELATED_DATASET_TYPES:
        continue

      assert np.all(node.UVW == 0)
      # Non-standard columns such as CORRECTED aren't yet exposed on the
      # datatree, so CORRECTED is verified directly on the table below.
      read = True

  assert read
  assert written

  # But we can check that CORRECTED has been written correctly
  with arcae.table(simmed_ms) as table:
    np.testing.assert_array_equal(table.getcol("CORRECTED"), 2 + 3j)


@pytest.mark.parametrize("simmed_ms", [{"name": "test_store_region.ms"}], indirect=True)
def test_store_region(monkeypatch, simmed_ms):
  """Create a zeroed CORRECTED column, then overwrite one region of it.

  After the region write, values inside the (time, frequency) region must
  equal ``1 + 2j`` and every value outside it must still be zero.
  """
  monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
  monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)

  region = {"time": slice(0, 2), "frequency": slice(2, 4)}
  sizes = None

  with xarray.open_datatree(simmed_ms, auto_corrs=True) as xdt:
    # Add a CORRECTED column
    for node in xdt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        ds = node.ds.assign(CORRECTED=xarray.zeros_like(node.VISIBILITY))
        xdt[node.path] = DataTree(ds)
        assert len(node.encoding) > 0

    # Create the new MS columns
    xdt.to_msv2(["CORRECTED"], compute=False)

    for node in xdt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        sizes = node.sizes
        # Slice this node's own dataset rather than whatever `ds` was left
        # bound by the previous loop.
        ds = node.ds.isel(**region)
        ds = ds.assign(CORRECTED=xarray.full_like(ds.CORRECTED, 1 + 2j))
        # Now write it out
        ds.to_msv2(["CORRECTED"], compute=False, region=region)

  assert sizes is not None, "no correlated dataset found in the datatree"

  # We can check that CORRECTED has been written correctly
  with arcae.table(simmed_ms) as T:
    corrected = T.getcol("CORRECTED")
    nt, nbl, nf, npol = (
      sizes[d] for d in ("time", "baseline_id", "frequency", "polarization")
    )
    corrected = corrected.reshape((nt, nbl, nf, npol))
    ts, fs = (region[d] for d in ("time", "frequency"))
    mask = np.full(corrected.shape, False, np.bool_)
    mask[ts, :, fs, :] = True
    np.testing.assert_array_equal(corrected[mask], 1 + 2j)
    np.testing.assert_array_equal(corrected[~mask], 0 + 0j)


@pytest.mark.parametrize("chunks", [{"time": 2, "frequency": 2}])
@pytest.mark.parametrize("simmed_ms", [{"name": "distributed-write.ms"}], indirect=True)
@pytest.mark.parametrize("nworkers", [4])
@pytest.mark.parametrize("processes", [True, False])
def test_distributed_write(simmed_ms, monkeypatch, processes, nworkers, chunks):
  """Write a dask-backed CORRECTED column through a local distributed cluster.

  CORRECTED is filled with ``arange`` values chunked like VISIBILITY, written
  with ``compute=True`` and then compared against the column read by arcae.
  The test is skipped when dask.array or dask.distributed is not installed.
  """
  monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
  monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)
  da = pytest.importorskip("dask.array")
  distributed = pytest.importorskip("dask.distributed")
  Client = distributed.Client
  LocalCluster = distributed.LocalCluster

  # Bound up front so a datatree without a correlated dataset fails with a
  # clear assertion instead of a NameError in the verification step.
  vis = sizes = None

  with ExitStack() as stack:
    cluster = stack.enter_context(LocalCluster(processes=processes, n_workers=nworkers))
    stack.enter_context(Client(cluster))
    dt = stack.enter_context(
      xarray.open_datatree(simmed_ms, chunks=chunks, auto_corrs=True)
    )
    for node in dt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        vis = node.VISIBILITY
        sizes = node.sizes
        corrected = da.arange(np.prod(vis.shape), dtype=np.int32)
        corrected = corrected.reshape(vis.shape).rechunk(vis.data.chunks)
        ds = node.ds.assign(CORRECTED=(vis.dims, corrected))
        dt[node.path] = DataTree(ds)
        assert len(node.encoding) > 0

    assert vis is not None, "no correlated dataset found in the datatree"

    # Create the new MS columns
    dt.to_msv2(["CORRECTED"], compute=False)

    for node in dt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        node.ds.to_msv2(["CORRECTED"], compute=True)

  with arcae.table(simmed_ms) as T:
    corrected = T.getcol("CORRECTED")
    shape = tuple(
      sizes[d] for d in ("time", "baseline_id", "frequency", "polarization")
    )
    expected = np.arange(np.prod(vis.shape), dtype=np.int32)
    # The table column is row-major: (time, baseline_id) collapse into rows
    expected = expected.reshape((-1,) + shape[2:])
    np.testing.assert_array_equal(corrected, expected)


@pytest.mark.parametrize("simmed_ms", [{"name": "indexed-write.ms"}], indirect=True)
def test_indexed_write(monkeypatch, simmed_ms):
  """Check that we throw if we select a variable out with an integer index
  and then try write that sub-selection out"""
  monkeypatch.setattr(Dataset, "to_msv2", dataset_to_msv2, raising=False)
  monkeypatch.setattr(DataTree, "to_msv2", datatree_to_msv2, raising=False)

  # Open through a context manager, as the other write tests do, so the
  # datatree is closed even when an assertion fails.
  with xarray.open_datatree(simmed_ms) as dt:
    assert len(dt.children) == 1

    for node in dt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        ds = node.ds.assign(CORRECTED=xarray.full_like(node.VISIBILITY, 1 + 2j))
        dt[node.path] = DataTree(ds)

    dt.to_msv2(["CORRECTED"], compute=False)

    for node in dt.subtree:
      if node.attrs.get("type") in CORRELATED_DATASET_TYPES:
        # The integer index on frequency drops that dimension from the data
        ds = node.ds.isel(time=slice(0, 2), baseline_id=slice(0, 2), frequency=1)
        with pytest.raises(MismatchedWriteRegion):
          ds.to_msv2(["CORRECTED"], compute=True)
67 changes: 60 additions & 7 deletions xarray_ms/backend/msv2/array.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import annotations

import re
from contextlib import contextmanager
from functools import reduce
from operator import mul
from typing import TYPE_CHECKING, Any, Callable, Tuple
from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple

import numpy as np
from xarray.backends import BackendArray
from xarray.core.indexing import (
IndexingSupport,
OuterIndexer,
expanded_indexer,
explicit_indexing_adapter,
)

from xarray_ms.errors import MismatchedWriteRegion

if TYPE_CHECKING:
import numpy.typing as npt

Expand All @@ -20,19 +25,52 @@

TransformerT = Callable[[npt.NDArray], npt.NDArray]

MAIN_PREFIX_DIMS = ("time", "baseline_id")

DATA_COLUMN_DIM_MISMATCH_RE = re.compile(
r"Number of data dimensions (?P<data>\d+) does not "
r"match number of column dimensions (?P<column>\d+)"
)


def slice_length(
  s: npt.NDArray | slice, max_len: int, op: Literal["read", "write"]
) -> int:
  """Return the number of elements selected by ``s`` along one dimension.

  Args:
    s: Either a 1D numpy index array or a slice with a step of 1.
    max_len: The current length of the dimension being indexed.
    op: ``"read"`` limits an explicit slice stop to at most ``max_len``.
      ``"write"`` takes the larger of ``max_len`` and the slice stop, so
      the selection may extend past the current length.

  Returns:
    ``len(s)`` for an index array, otherwise ``stop - start`` of the
    normalised slice.

  Raises:
    ValueError: If ``op`` is neither ``"read"`` nor ``"write"``.
    NotImplementedError: If ``s`` is a numpy array that is not 1D,
      or a slice whose step is not 1.
  """
  # Reject typos explicitly: the clamp selection below would otherwise
  # silently treat any unknown value as a write.
  if op not in ("read", "write"):
    raise ValueError(f"op must be 'read' or 'write', got {op!r}")

  if isinstance(s, np.ndarray):
    if s.ndim != 1:
      raise NotImplementedError("Slicing with non-1D numpy arrays")
    return len(s)

  clamp_op = min if op == "read" else max
  start, stop, step = s.indices(
    clamp_op(max_len, s.stop) if s.stop is not None else max_len
  )
  if step != 1:
    raise NotImplementedError(f"Slicing with steps {s} other than 1 not supported")
  return stop - start


@contextmanager
def rethrow_arcae_exceptions():
  """Context manager translating selected pyarrow errors into xarray-ms ones.

  A ``pyarrow.lib.ArrowInvalid`` raised inside the managed block whose message
  matches ``DATA_COLUMN_DIM_MISMATCH_RE`` is re-raised as a
  ``MismatchedWriteRegion`` chained to the original error. Any other
  ``ArrowInvalid`` propagates unchanged.

  Raises:
    MismatchedWriteRegion: On a data/column dimensionality mismatch.
  """
  # Imported lazily so pyarrow is only needed once the manager is entered
  import pyarrow as pa

  try:
    yield
  except pa.lib.ArrowInvalid as e:
    if m := DATA_COLUMN_DIM_MISMATCH_RE.match(str(e)):
      data_dims = m.group("data")
      column_dims = m.group("column")
      raise MismatchedWriteRegion(
        f"Attempted to write an array of dimensionality {data_dims} "
        f"to a column of dimensionality {column_dims}. "
        f"This can occur if xarray variable dimensions "
        f"have been selected out with integer indices. "
        f"Prefer xarray semantics that retain dimensionality "
        f"such as slicing during selection or sum(..., keepdims=True)"
      ) from e
    raise


class MSv2Array(BackendArray):
"""Base MSv2Array backend array class,
containing required shape and data type"""
Expand All @@ -49,6 +87,9 @@ def __init__(self, shape: Tuple[int, ...], dtype: npt.DTypeLike):
def __getitem__(self, key) -> npt.NDArray:
raise NotImplementedError

def __setitem__(self, key, value) -> None:
raise NotImplementedError

@property
def transform(self) -> TransformerT | None:
raise NotImplementedError
Expand Down Expand Up @@ -99,7 +140,7 @@ def __init__(
self._default = default
self._transform = transform

assert len(shape) >= 2, "(time, baseline_ids) required"
assert len(shape) >= len(MAIN_PREFIX_DIMS), f"{MAIN_PREFIX_DIMS} required"

def __getitem__(self, key) -> npt.NDArray:
return explicit_indexing_adapter(
Expand All @@ -118,20 +159,32 @@ def promote_integer_dims(key):
def _getitem(self, key) -> npt.NDArray:
  """Read the selection described by ``key`` from the table column.

  Args:
    key: One indexer per dimension of ``self.shape``.

  Returns:
    The selected data with integer-indexed axes squeezed out and
    ``self._transform`` applied when one is configured. An empty selection
    returns an uninitialised array of the squeezed shape without touching
    the table or the transform.
  """
  assert len(key) == len(self.shape)
  key, squeeze_axis = self.promote_integer_dims(key)
  expected_shape = tuple(slice_length(k, s, "read") for k, s in zip(key, self.shape))
  if reduce(mul, expected_shape, 1) == 0:
    # Squeeze here as well so that the rank of an empty result matches
    # the rank of a non-empty one for the same key.
    return np.empty(expected_shape, dtype=self.dtype).squeeze(axis=squeeze_axis)
  nprefix = len(MAIN_PREFIX_DIMS)
  # Map the (time, baseline_id) coordinates onto row indices
  rows = self._structure_factory.instance[self._partition].row_map[key[:nprefix]]
  row_key = (rows.ravel(),) + key[nprefix:]
  row_shape = (rows.size,) + expected_shape[nprefix:]
  # Pre-fill with the default; getcol is handed this buffer to fill
  result = np.full(row_shape, self._default, dtype=self.dtype)
  with rethrow_arcae_exceptions():
    self._table_factory.instance.getcol(self._column, row_key, result)
  result = result.reshape(rows.shape + expected_shape[nprefix:])
  # arcae doesn't handle squeezing out the selecting axis so we do it here.
  result = result.squeeze(axis=squeeze_axis)
  return self._transform(result) if self._transform else result

def __setitem__(self, key, value: npt.NDArray) -> None:
  """Write ``value`` into the table column at the selection given by ``key``.

  ``key`` is first expanded to one indexer per dimension of ``self.shape``.
  ``self._transform`` is applied to ``value`` when one is configured, and
  the result is reshaped to (rows, ...) before being handed to ``putcol``.
  """
  full_key = expanded_indexer(key, len(self.shape))
  write_shape = tuple(
    slice_length(indexer, dim_len, "write")
    for indexer, dim_len in zip(full_key, self.shape)
  )
  # The two leading indexers select rows through the partition's row map
  partition = self._structure_factory.instance[self._partition]
  rows = partition.row_map[full_key[:2]]
  index = (rows.ravel(),) + full_key[2:]
  flat_shape = (rows.size,) + write_shape[2:]
  if self._transform:
    value = self._transform(value)
  value = value.reshape(flat_shape)
  with rethrow_arcae_exceptions():
    self._table_factory.instance.putcol(self._column, value, index=index)

@property
def transform(self) -> TransformerT | None:
  """The transform stored on this array as ``self._transform``."""
  return self._transform
Expand Down Expand Up @@ -190,7 +243,7 @@ def __getitem__(self, key) -> npt.NDArray:
)

def _getitem(self, key) -> npt.NDArray:
expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
expected_shape = tuple(slice_length(k, s, "read") for k, s in zip(key, self.shape))
if reduce(mul, expected_shape, 1) == 0:
return np.empty(expected_shape, dtype=self.dtype)
low_res_key = tuple(k for i, k in zip(self._low_res_index, key) if i is not None)
Expand Down
Loading