Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/xradio/image/_util/_casacore/xds_from_casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,15 +674,13 @@ def _get_persistent_block(
infile: str,
shapes: tuple,
starts: tuple,
dimorder: list,
transpose_list: list,
new_axes: list,
) -> xr.DataArray:
) -> da.Array:
block = _read_image_chunk(infile, shapes, starts)
block = np.expand_dims(block, new_axes)
block = block.transpose(transpose_list)
block = da.from_array(block, chunks=block.shape)
block = xr.DataArray(block, dims=dimorder)
return block


Expand Down
14 changes: 10 additions & 4 deletions src/xradio/image/_util/casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,15 @@
def _squeeze_if_needed(ary: da, image_type: str) -> da:
if image_type.upper() == "VISIBILITY_NORMALIZATION":
shape = ary.shape
if len(shape) != 5:
raise ValueError(
"VISIBILITY_NORMALIZATION casa image must be 5D before squeezing. "
f"Found shape {shape}"
)
if shape[3] != 1 or shape[4] != 1:
raise ValueError(
"VISIBILITY_NORMALIZATION casa image must have l and m of length 1. Found "
+ [shape[3], shape[4]]
"VISIBILITY_NORMALIZATION casa image must have l and m of length 1. "
f"Found {(shape[3], shape[4])}"
)
ary = ary.squeeze(axis=(3, 4))
return ary
Expand Down Expand Up @@ -99,7 +104,7 @@ def _load_casa_image_block(
starts, shapes, slices = _get_starts_shapes_slices(block_des, coords, cshape)
transpose_list, new_axes = _get_transpose_list(coords)
block = _get_persistent_block(
image_full_path, shapes, starts, dimorder, transpose_list, new_axes
image_full_path, shapes, starts, transpose_list, new_axes
)
block = _squeeze_if_needed(block, image_type)
xds = _add_sky_or_aperture(
Expand All @@ -109,8 +114,9 @@ def _load_casa_image_block(
for m in mymasks:
full_path = os.sep.join([image_full_path, m])
block = _get_persistent_block(
full_path, shapes, starts, dimorder, transpose_list, new_axes
full_path, shapes, starts, transpose_list, new_axes
)
block = _squeeze_if_needed(block, image_type)
# data vars are all caps by convention
Comment on lines 116 to 120
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The newly added _squeeze_if_needed(...) call can trigger an exception path where _squeeze_if_needed attempts to build the ValueError message via string concatenation with a Python list ("... Found " + [shape[3], shape[4]]), which will raise a TypeError and mask the intended validation error. Please format the shape values as a string (e.g., tuple/list via f-string/str(...)) so the correct ValueError is raised when l/m are not length 1 (and consider guarding against unexpected dimensionality before indexing shape[3]/shape[4]).

Copilot uses AI. Check for mistakes.
Comment on lines 116 to 120
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new squeeze step inside the mask loop isn’t covered by the added unit test (the synthetic sumwt image in the test has no mask, so the loop never executes). Consider extending/adding a test that creates a VISIBILITY_NORMALIZATION image with a mask (e.g., MASK_0/default mask) and asserts the returned flag/mask data variable is also squeezed to (time, frequency, polarization) with no l/m dims.

Copilot uses AI. Check for mistakes.
mask_name = re.sub(r"\bMASK(\d+)\b", r"MASK_\1", m.upper())
xds = _add_mask(xds, mask_name, block, dimorder)
Expand Down
87 changes: 40 additions & 47 deletions src/xradio/image/_util/image_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,19 @@ def create_image_xds_from_store(
img_xds[image_type] = xds[image_type]
img_xds[image_type].attrs["type"] = image_type.lower()

expected_flag_name = "FLAG_" + image_type

def _add_flag_to_output(
img_xds: xr.Dataset,
flag_array: xr.DataArray,
expected_flag_name: str,
active_group: dict | None = None,
):
img_xds[expected_flag_name] = flag_array
img_xds[expected_flag_name].attrs["type"] = "flag"
if active_group is not None:
active_group["flag"] = expected_flag_name

active_data_group_name = None
# If sky image, handle internal masks and beam fit params.
if "sky" in image_type.lower():
Expand All @@ -925,55 +938,35 @@ def create_image_xds_from_store(
data_groups[active_data_group_name]["beam_fit_params_sky"] = (
"BEAM_FIT_PARAMS_" + image_type.upper()
)
expected_flag_name = "FLAG_" + image_type

# TODO remove this mask logic and everything that still makes it necessary
def _add_flag_to_group(
img_xds: xr.Dataset,
flag_array: xr.DataArray,
expected_flag_name: str,
active_group: dict,
):
img_xds[expected_flag_name] = flag_array
img_xds[expected_flag_name].attrs["type"] = "flag"
active_group["flag"] = expected_flag_name

if expected_flag_name in xds:
_add_flag_to_group(
img_xds,
xds[expected_flag_name],
expected_flag_name,
data_groups[active_data_group_name],
)

if "MASK_0" in xds:
_add_flag_to_group(
img_xds,
xds["MASK_0"],
expected_flag_name,
data_groups[active_data_group_name],
)
"""
TODO delete old code when certain new function works
img_xds[expected_flag_name] = xds["MASK_0"]
data_groups[active_data_group_name]["flag"] = expected_flag_name
img_xds[expected_flag_name].attrs["type"] = "flag"
"""
if "MASK" in xds:
_add_flag_to_group(
img_xds,
xds["MASK"],
expected_flag_name,
data_groups[active_data_group_name],
)
"""
TODO delete old code when certain new function works
img_xds[expected_flag_name] = xds["MASK"]
data_groups[active_data_group_name]["flag"] = expected_flag_name
img_xds[expected_flag_name].attrs["type"] = "flag"
"""
img_xds[image_type].attrs["type"] = "sky"

active_group = (
data_groups[active_data_group_name]
if active_data_group_name is not None
else None
)
if expected_flag_name in xds:
_add_flag_to_output(
img_xds,
xds[expected_flag_name],
expected_flag_name,
active_group,
)
elif "MASK_0" in xds:
_add_flag_to_output(
img_xds,
xds["MASK_0"],
expected_flag_name,
active_group,
)
elif "MASK" in xds:
_add_flag_to_output(
img_xds,
xds["MASK"],
expected_flag_name,
active_group,
)

# If point spread function, handle beam fit params.
if "point_spread_function" in image_type.lower():
if "BEAM_FIT_PARAMS_" + image_type.upper() in xds:
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/image/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
open_image,
write_image,
)
from xradio.image._util.casacore import _squeeze_if_needed
from xradio.image._util._casacore.common import _create_new_image as create_new_image
from xradio.image._util._casacore.common import _open_image_ro as open_image_ro
from xradio.image._util.common import _image_type as image_type
Expand All @@ -57,6 +58,66 @@ def clean_path_logic(text: str) -> str:
return text


def test_load_visibility_normalization_block_squeezes_spatial_axes(tmp_path):
imagename = tmp_path / "synthetic.sumwt"
data = np.arange(8, dtype=np.float32).reshape(4, 2, 1, 1)
masked_data = ma.masked_array(data, np.zeros_like(data, dtype=bool))

with create_new_image(str(imagename), shape=list(data.shape)) as im:
im.put(masked_data)

xds = load_image({"visibility_normalization": str(imagename)})

assert xds.VISIBILITY_NORMALIZATION.dims == (
"time",
"frequency",
"polarization",
)
assert xds.VISIBILITY_NORMALIZATION.shape == (1, 4, 2)
assert "l" not in xds.dims
assert "m" not in xds.dims
np.testing.assert_array_equal(
xds.VISIBILITY_NORMALIZATION.values,
data[np.newaxis, :, :, 0, 0],
)


def test_squeeze_if_needed_rejects_non_singleton_spatial_axes():
data = da.from_array(np.zeros((1, 4, 2, 2, 3), dtype=np.float32))

with pytest.raises(
ValueError,
match=r"VISIBILITY_NORMALIZATION casa image must have l and m of length 1\. Found \(2, 3\)",
):
_squeeze_if_needed(data, "VISIBILITY_NORMALIZATION")


def test_load_visibility_normalization_mask_squeezes_spatial_axes(tmp_path):
imagename = tmp_path / "masked.sumwt"
data = np.arange(8, dtype=np.float32).reshape(4, 2, 1, 1)
mask = np.zeros_like(data, dtype=bool)
mask[1, 0, 0, 0] = True
masked_data = ma.masked_array(data, mask)

with create_new_image(str(imagename), shape=list(data.shape), mask="MASK_0") as im:
im.put(masked_data)

xds = load_image({"visibility_normalization": str(imagename)})

assert xds.FLAG_VISIBILITY_NORMALIZATION.dims == (
"time",
"frequency",
"polarization",
)
assert xds.FLAG_VISIBILITY_NORMALIZATION.shape == (1, 4, 2)
assert "l" not in xds.FLAG_VISIBILITY_NORMALIZATION.dims
assert "m" not in xds.FLAG_VISIBILITY_NORMALIZATION.dims
np.testing.assert_array_equal(
xds.FLAG_VISIBILITY_NORMALIZATION.values,
mask[np.newaxis, :, :, 0, 0],
)


@pytest.fixture(scope="module")
def dask_client_module():
"""Set up and tear down a Dask client for the test module.
Expand Down
Loading