Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/xradio/measurement_set/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
convert_msv2_to_processing_set,
estimate_conversion_memory_and_cores,
)
from .open_msv2 import open_msv2
except ModuleNotFoundError as exc:
warnings.warn(
f"Could not import the function to convert from MSv2 to MSv4. "
Expand All @@ -34,5 +35,9 @@
)
else:
__all__.extend(
["convert_msv2_to_processing_set", "estimate_conversion_memory_and_cores"]
[
"convert_msv2_to_processing_set",
"estimate_conversion_memory_and_cores",
"open_msv2",
]
)
63 changes: 63 additions & 0 deletions src/xradio/measurement_set/_msv2_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""xarray backend engine for reading MSv2 as MSv4-schema datasets.

Registers ``xradio:msv2`` so that users can write::

xr.open_datatree("path/to/ms.ms", engine="xradio:msv2")

and get back the same DataTree that :func:`open_msv2` produces.
"""

import os

import xarray as xr
from xarray.backends import BackendEntrypoint


class MSv2BackendEntrypoint(BackendEntrypoint):
    """xarray backend for CASA MSv2 tables via xradio.

    Registered as ``xradio:msv2`` so that
    ``xr.open_datatree(path, engine="xradio:msv2")`` returns the same
    DataTree that ``open_msv2`` produces.
    """

    description = "Open CASA MSv2 tables as MSv4-schema DataTree via xradio"
    open_dataset_parameters = [
        "filename_or_obj",
        "drop_variables",
        "partition_scheme",
        "main_chunksize",
        "with_pointing",
        "pointing_interpolate",
        "ephemeris_interpolate",
        "phase_cal_interpolate",
        "sys_cal_interpolate",
    ]

    def open_datatree(
        self,
        filename_or_obj,
        *,
        drop_variables=None,
        partition_scheme=None,
        main_chunksize=None,
        with_pointing=True,
        pointing_interpolate=False,
        ephemeris_interpolate=False,
        phase_cal_interpolate=False,
        sys_cal_interpolate=False,
    ) -> xr.DataTree:
        """Open an MSv2 table as an MSv4-schema :class:`xr.DataTree`.

        All keyword arguments except ``drop_variables`` are forwarded to
        :func:`open_msv2` unchanged.

        NOTE(review): ``drop_variables`` is accepted for xarray interface
        compatibility but is currently ignored — confirm this is intended.
        """
        # Imported inside the method rather than at module level —
        # presumably to avoid a circular import; confirm before moving.
        from xradio.measurement_set.open_msv2 import open_msv2

        forwarded = dict(
            partition_scheme=partition_scheme,
            main_chunksize=main_chunksize,
            with_pointing=with_pointing,
            pointing_interpolate=pointing_interpolate,
            ephemeris_interpolate=ephemeris_interpolate,
            phase_cal_interpolate=phase_cal_interpolate,
            sys_cal_interpolate=sys_cal_interpolate,
        )
        return open_msv2(str(filename_or_obj), **forwarded)

    def guess_can_open(self, filename_or_obj) -> bool:
        """Heuristic: an MSv2 table is a directory containing ``table.dat``."""
        try:
            candidate = str(filename_or_obj)
        except Exception:
            # Unstringifiable objects cannot be an MSv2 path.
            return False
        if not os.path.isdir(candidate):
            return False
        return os.path.isfile(os.path.join(candidate, "table.dat"))
155 changes: 153 additions & 2 deletions src/xradio/measurement_set/_utils/_msv2/_tables/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Tuple, Union

import dask.array as da
from dask.utils import SerializableLock
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -1284,6 +1285,8 @@ def read_col_conversion_dask(
-------
da.Array
"""
# Serialize casacore access across dask threads (casacore is not thread-safe)
_casacore_lock = SerializableLock()

# Use casacore to get the shape of a row for this column
#################################################################################
Expand Down Expand Up @@ -1337,6 +1340,7 @@ def read_col_conversion_dask(
rows_per_time=rows_per_time,
cshape=cshape,
extra_dimensions=extra_dimensions,
lock=_casacore_lock,
drop_axis=[1],
new_axis=list(range(1, len(cshape + extra_dimensions))),
meta=np.array([], dtype=col_dtype),
Expand All @@ -1356,6 +1360,7 @@ def load_col_chunk(
rows_per_time,
cshape,
extra_dimensions,
lock=None,
):
start_row = x[0][0]
end_row = x[0][1]
Expand All @@ -1368,8 +1373,15 @@ def load_col_chunk(

# Load data from the column
# Release the casacore table as soon as possible
with table_manager.get_table() as tb_tool:
tb_tool.getcolnp(col_name, row_data, startrow=start_row, nrow=num_rows)
# Acquire lock to serialize casacore access (not thread-safe)
if lock is not None:
lock.acquire()
try:
with table_manager.get_table() as tb_tool:
tb_tool.getcolnp(col_name, row_data, startrow=start_row, nrow=num_rows)
finally:
if lock is not None:
lock.release()

# Initialise reshaped numpy array
reshaped_data = np.full(
Expand All @@ -1389,3 +1401,142 @@ def load_col_chunk(
reshaped_data[tidxs_slc, bidxs_slc] = row_data

return reshaped_data


def read_col_conversion_dask_sparse(
    table_manager: TableManager,
    col: str,
    cshape: Tuple[int],
    tidxs: np.ndarray,
    bidxs: np.ndarray,
    use_table_iter: bool,
    time_chunksize: int,
) -> da.Array:
    """Lazy dask reader for sparse data (not every time has every baseline).

    Unlike :func:`read_col_conversion_dask`, this function does NOT assume
    that the number of MSv2 rows equals ``ntime * nbaseline``. Instead it
    groups rows by time index, builds per-time-chunk row-ranges, and pads
    missing baselines with fill values.

    The function signature matches :func:`read_col_conversion_dask` and
    :func:`read_col_conversion_numpy` so it is a drop-in replacement in
    :func:`get_read_col_conversion_function`.

    Parameters
    ----------
    table_manager :
        Provides (re-)openable access to the casacore table.
    col :
        Name of the MSv2 column to read.
    cshape :
        Leading output shape; ``cshape[0]`` is the number of unique times
        and ``cshape[1]`` the number of baselines.
    tidxs :
        Per-row time index into the unique-times axis.
        NOTE(review): the cumulative-offset logic below assumes ``tidxs``
        is non-decreasing (rows sorted/grouped by time) — confirm callers
        guarantee this.
    bidxs :
        Per-row baseline index; out-of-range entries are dropped by the
        per-chunk loader.
    use_table_iter :
        Unused here; kept only so the signature matches the sibling
        readers.
    time_chunksize :
        Requested dask chunk size along the time axis.

    Returns
    -------
    da.Array
        Lazy array of shape ``(ntime, nbaseline) + extra_dimensions``.
    """
    # Serialize casacore access across dask threads (casacore is not
    # thread-safe); the lock is serializable so the graph can be pickled.
    _casacore_lock = SerializableLock()

    # Probe the column once to learn its per-row shape and dtype.
    with table_manager.get_table() as tb_tool:
        if tb_tool.isscalarcol(col):
            extra_dimensions = ()
        else:
            # getcolshapestring yields e.g. "[4, 2]" for the first row;
            # NOTE(review): this assumes a fixed shape across all rows.
            shape_string = tb_tool.getcolshapestring(col)[0]
            extra_dimensions = tuple(
                int(dim) for dim in shape_string.strip("[]").split(", ")
            )
        col_dtype = np.array(tb_tool.col(col)[0]).dtype

    # Value used to pad (time, baseline) cells that have no MSv2 row.
    fill_value = _sparse_pad_value(col_dtype)

    num_utimes = cshape[0]
    n_baselines = cshape[1]

    # Build cumulative row offsets per unique time, so that rows for
    # times [t0, t1) occupy MSv2 rows [cum_rows[t0], cum_rows[t1]).
    rows_per_time = np.bincount(tidxs, minlength=num_utimes)
    cum_rows = np.empty(num_utimes + 1, dtype=np.int64)
    cum_rows[0] = 0
    np.cumsum(rows_per_time, out=cum_rows[1:])

    # Chunk along the time axis; normalize_chunks returns per-axis chunk
    # tuples, we only need the time axis.
    tmp_chunks = da.core.normalize_chunks(time_chunksize, (num_utimes,))[0]

    # Build (start_row, end_row, n_times_in_chunk) per chunk.
    chunk_specs = []
    t_offset = 0
    for chunk in tmp_chunks:
        start_row = int(cum_rows[t_offset])
        end_row = int(cum_rows[t_offset + chunk])
        chunk_specs.append((start_row, end_row, chunk))
        t_offset += chunk

    # One (1, 3) block per time chunk, so each map_blocks call receives
    # exactly one spec row.
    arr_specs = da.from_array(np.array(chunk_specs, dtype=np.int64), chunks=(1, 3))

    output_chunkshape = (tmp_chunks, n_baselines) + extra_dimensions

    # drop_axis/new_axis reshape the (1, 3) spec blocks into the full
    # (time, baseline, *extra) output blocks declared by `chunks`.
    data = arr_specs.map_blocks(
        _load_col_chunk_sparse,
        table_manager=table_manager,
        col_name=col,
        col_dtype=col_dtype,
        fill_value=fill_value,
        tidxs=tidxs,
        bidxs=bidxs,
        n_baselines=n_baselines,
        extra_dimensions=extra_dimensions,
        lock=_casacore_lock,
        drop_axis=[1],
        new_axis=list(range(1, len(cshape + extra_dimensions))),
        meta=np.array([], dtype=col_dtype),
        chunks=output_chunkshape,
    )

    return data


def _load_col_chunk_sparse(
x,
table_manager,
col_name,
col_dtype,
fill_value,
tidxs,
bidxs,
n_baselines,
extra_dimensions,
lock=None,
):
"""Per-chunk read for sparse data."""
start_row = int(x[0][0])
end_row = int(x[0][1])
num_utimes = int(x[0][2])
num_rows = end_row - start_row

if num_rows == 0:
return np.full(
(num_utimes, n_baselines) + extra_dimensions, fill_value, dtype=col_dtype
)

row_data = np.full((num_rows,) + extra_dimensions, fill_value, dtype=col_dtype)

if lock is not None:
lock.acquire()
try:
with table_manager.get_table() as tb_tool:
tb_tool.getcolnp(col_name, row_data, startrow=start_row, nrow=num_rows)
finally:
if lock is not None:
lock.release()

reshaped_data = np.full(
(num_utimes, n_baselines) + extra_dimensions, fill_value, dtype=col_dtype
)

slc = slice(start_row, end_row)
tidxs_slc = tidxs[slc] - tidxs[start_row]
bidxs_slc = bidxs[slc]

# Only scatter rows with valid baseline index
valid = (bidxs_slc >= 0) & (bidxs_slc < n_baselines)
reshaped_data[tidxs_slc[valid], bidxs_slc[valid]] = row_data[valid]

return reshaped_data


def _sparse_pad_value(dtype: np.dtype):
"""Return the fill value for missing baselines, matching get_pad_value semantics."""
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
elif np.issubdtype(dtype, np.bool_):
return True # Missing → flagged
elif np.issubdtype(dtype, np.integer):
return 0
return np.nan
26 changes: 16 additions & 10 deletions src/xradio/measurement_set/_utils/_msv2/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
extract_table_attributes,
read_col_conversion_numpy,
read_col_conversion_dask,
read_col_conversion_dask_sparse,
load_generic_table,
)
from ._tables.read_main_table import get_baselines, get_baseline_indices, get_utimes_tol
Expand Down Expand Up @@ -693,7 +694,8 @@ def create_data_variables(
def get_read_col_conversion_function(col_name: str, parallel_mode: str) -> Callable:
"""
Returns the appropriate read_col_conversion function: use the dask version
for large columns and parallel_mode="time", or the numpy version otherwise.
for large columns and parallel_mode="time" or "sparse", or the numpy
version otherwise.
"""
large_columns = {
"DATA",
Expand All @@ -703,11 +705,13 @@ def get_read_col_conversion_function(col_name: str, parallel_mode: str) -> Calla
"WEIGHT",
"FLAG",
}
return (
read_col_conversion_dask
if parallel_mode == "time" and col_name in large_columns
else read_col_conversion_numpy
)
if col_name not in large_columns:
return read_col_conversion_numpy
if parallel_mode == "time":
return read_col_conversion_dask
if parallel_mode == "sparse":
return read_col_conversion_dask_sparse
return read_col_conversion_numpy


def repeat_weight_array(
Expand Down Expand Up @@ -1118,9 +1122,9 @@ def get_observation_info(in_file, observation_id, scan_intents):
"software_name": "xradio",
"version": importlib.metadata.version("xradio"),
},
"creation_date": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
"creation_date": (
datetime.datetime.now(datetime.timezone.utc).isoformat()
),
"type": "visibility",
}
)
Expand Down Expand Up @@ -1524,7 +1528,9 @@ def add_group_to_data_groups(
"flag": "FLAG",
"weight": "WEIGHT",
"field_and_source": f"field_and_source_{what_group}_xds",
"description": f"Data group derived from the data column '{correlated_data_name}' of an MSv2 converted to MSv4",
"description": (
f"Data group derived from the data column '{correlated_data_name}' of an MSv2 converted to MSv4"
),
"date": datetime.datetime.now(datetime.timezone.utc).isoformat(),
}
if uvw:
Expand Down
Loading
Loading