Skip to content
Merged
15 changes: 10 additions & 5 deletions src/astrohack/core/extract_pointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def extract_pointing_chunk(pnt_params: dict, output_mds: AstrohackPointFile):
tb.close()
table_obj.close()

_evaluate_time_samping(direction_time, ant_name)
_evaluate_time_sampling(direction_time, ant_name)

pnt_xds = xr.Dataset()
coords = {"time": direction_time}
Expand Down Expand Up @@ -379,17 +379,22 @@ def _extract_scan_time_dict_jit(time, scan_ids, state_ids, ddi_ids, mapping_stat
return scan_time_dict


def _evaluate_time_samping(
time_sampling, data_label, threshold=0.01, expected_interval=0.1
def _evaluate_time_sampling(
time_sampling, data_label, threshold=0.01, expected_interval=None
):
intervals = np.diff(time_sampling)
unq_intervals, counts = np.unique(intervals, return_counts=True)
if expected_interval is None:
i_max_count = np.argmax(counts)
expected_interval = unq_intervals[i_max_count]

bin_sz = expected_interval / 4
time_bin_edge = np.arange(-bin_sz / 2, 2.5 * expected_interval, bin_sz)
time_bin_axis = time_bin_edge[:-1] + bin_sz / 2
i_mid = int(np.argmin(np.abs(time_bin_axis - expected_interval)))

intervals = np.diff(time_sampling)
hist, edges = np.histogram(intervals, bins=time_bin_edge)
n_total = np.sum(hist)
n_total = np.nansum(hist)
outlier_fraction = 1 - hist[i_mid] / n_total

if outlier_fraction > threshold:
Expand Down
53 changes: 44 additions & 9 deletions src/astrohack/core/image_comparison_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0):
"""
self.filename = fits_filename
self.telescope_name = telescope_name
self.rootname = ".".join(fits_filename.split(".")[:-1]) + "."
fits_real_filename = fits_filename.split("/")[-1]
self.rootname = ".".join(fits_real_filename.split(".")[:-1]) + "."
self.header, self.data = read_fits(self.filename, header_as_dict=True)
stokes_iaxis = get_stokes_axis_iaxis(self.header)

Expand All @@ -152,12 +153,16 @@ def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0):
self.y_axis, _, self.y_unit = get_axis_from_fits_header(
self.header, 2, pixel_offset=False
)
offset_scale = 1.5
x_offset = offset_scale * np.unique(np.diff(self.x_axis))[0]
y_offset = offset_scale * np.unique(np.diff(self.y_axis))[0]
self.x_axis = np.flip(self.x_axis + x_offset)
self.y_axis = np.flip(self.y_axis + y_offset)
self.x_unit = "m"
self.y_unit = "m"
elif "Astrohack" in self.header["ORIGIN"]:
self.x_axis, _, self.x_unit = get_axis_from_fits_header(self.header, 1)
self.y_axis, _, self.y_unit = get_axis_from_fits_header(self.header, 2)
self.data = np.fliplr(self.data)
else:
raise NotImplementedError(f'Unrecognized origin:\n{self.header["origin"]}')
self._create_base_mask()
Expand Down Expand Up @@ -208,11 +213,13 @@ def resample(self, ref_image):
x_mesh_dest, y_mesh_dest = np.meshgrid(
ref_image.x_axis, ref_image.y_axis, indexing="ij"
)
raveled_data = self.data.ravel()
valid_data = np.isfinite(raveled_data)
resamp = griddata(
(x_mesh_orig.ravel(), y_mesh_orig.ravel()),
self.data.ravel(),
(x_mesh_orig.ravel()[valid_data], y_mesh_orig.ravel()[valid_data]),
raveled_data[valid_data],
(x_mesh_dest.ravel(), y_mesh_dest.ravel()),
method="linear",
method="nearest",
)
size = ref_image.x_axis.shape[0], ref_image.y_axis.shape[0]
self.x_axis = ref_image.x_axis
Expand Down Expand Up @@ -592,14 +599,24 @@ def export_to_fits(self, destination):
reorder_axis=False,
)

def scatter_plot(self, destination, ref_image, dpi=300, display=False):
def scatter_plot(
self,
destination,
ref_image,
dpi=300,
display=False,
max_radius=None,
min_radius=None,
):
"""
Produce a scatter plot of self.data agains ref_image.data
Args:
destination: Location to store scatter plot
ref_image: Reference FITSImage object
dpi: png resolution on disk
display: Show interactive view of plot
max_radius: Maximum radius for scatter plot comparison as the outer panels can be crappy.
min_radius: Minimum radius for scatter plot comparison as the innermost panels can be crappy.

Returns:
None
Expand All @@ -610,10 +627,23 @@ def scatter_plot(self, destination, ref_image, dpi=300, display=False):

fig, ax = plt.subplots(1, 1, figsize=[10, 8])

x_mesh_orig, y_mesh_orig = np.meshgrid(self.x_axis, self.y_axis, indexing="ij")
radius = np.sqrt(x_mesh_orig**2 + y_mesh_orig**2)

telescope = get_proper_telescope(self.telescope_name)
if min_radius is None:
min_radius = telescope.inner_radial_limit
if max_radius is None:
max_radius = telescope.outer_radial_limit - 1.0
scatter_mask = np.isfinite(ref_image.data)
scatter_mask = np.where(np.isfinite(self.data), scatter_mask, False)
scatter_mask = np.where(radius < max_radius, scatter_mask, False)
scatter_mask = np.where(radius > min_radius, scatter_mask, False)

ydata = self.data[scatter_mask]
xdata = ref_image.data[scatter_mask]
pl_max = np.max((np.max(xdata), np.max(ydata)))
pl_min = np.min((np.min(xdata), np.min(ydata)))

scatter_plot(
ax,
Expand All @@ -622,6 +652,12 @@ def scatter_plot(self, destination, ref_image, dpi=300, display=False):
ydata,
f"{self.filename} [{self.unit}]",
add_regression=True,
regression_method="siegelslopes",
add_regression_reference=True,
regression_reference_label="Perfect agreement",
xlim=(pl_min, pl_max),
ylim=[pl_min, pl_max],
force_equal_aspect=True,
)
close_figure(
fig,
Expand All @@ -641,7 +677,6 @@ def image_comparison_chunk(compare_params):
Returns:
A DataTree containing the Image and its reference Image.
"""

image = FITSImage.from_fits_file(
compare_params["this_image"], compare_params["telescope_name"]
)
Expand Down Expand Up @@ -698,8 +733,8 @@ def image_comparison_chunk(compare_params):
if compare_params["plot_scatter"]:
image.scatter_plot(destination, ref_image, dpi=dpi, display=display)

img_node = xr.DataTree(name=image.filename, dataset=image.export_as_xds())
ref_node = xr.DataTree(name=ref_image.filename, dataset=ref_image.export_as_xds())
img_node = xr.DataTree(name=image.rootname, dataset=image.export_as_xds())
ref_node = xr.DataTree(name=ref_image.rootname, dataset=ref_image.export_as_xds())
tree_node = xr.DataTree(
name=image.rootname[:-1], children={"Reference": ref_node, "Image": img_node}
)
Expand Down
4 changes: 4 additions & 0 deletions src/astrohack/extract_holog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
from copy import deepcopy

import toolviper.utils.parameter
import toolviper.utils.logger as logger
Expand Down Expand Up @@ -187,6 +188,9 @@ def extract_holog(
}

"""
# This copy here ensures that the user space holog_obs_dict given to extract_holog is not modified during execution
if holog_obs_dict is not None:
holog_obs_dict = deepcopy(holog_obs_dict)

# Doing this here allows it to get captured by locals()
holog_name = get_default_file_name(ms_name, ".holog.zarr", holog_name)
Expand Down
2 changes: 1 addition & 1 deletion src/astrohack/image_comparison_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def rms_table_from_zarr_datatree(
raise FileNotFoundError

xdt = xr.open_datatree(input_params["zarr_data_tree"])
if xdt.attrs["origin"] != "compare_fits_images":
if xdt.attrs["origin_info"]["creator_function"] != "compare_fits_images":
logger.error("Data tree file was not created by astrohack.compare_fits_images")
raise ValueError

Expand Down
3 changes: 3 additions & 0 deletions src/astrohack/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def get_default_file_name(
else:
output_filename = user_filename

if output_filename[-1] == "/":
output_filename = output_filename[:-1]

logger.info(f"Creating output file name: {output_filename}")
return output_filename

Expand Down
35 changes: 33 additions & 2 deletions src/astrohack/visualization/plot_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import matplotlib.image
import numpy as np
from scipy.stats import linregress
from scipy.stats import linregress, theilslopes, siegelslopes

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
Expand Down Expand Up @@ -251,6 +251,12 @@ def scatter_plot(
add_regression=False,
regression_linestyle="-",
regression_color="black",
regression_method="linregress",
add_regression_reference=False,
regression_reference=(1.0, 0.0),
regression_reference_color="orange",
regression_reference_label="Regression refrence",
force_equal_aspect=False,
add_legend=True,
legend_location="best",
):
Expand Down Expand Up @@ -287,6 +293,12 @@ def scatter_plot(
add_regression: Add a linear regression between X and y data
regression_linestyle: Line style for the regression plot
regression_color: Color for the regression plot
regression_method: Which scipy function to use for the linear regression: linregress, theilslopes or siegelslopes
add_regression_reference: Add reference for the expected regression result
regression_reference: 2 value array/tuple/list with a slope and intercept for reference
regression_reference_color: Color for reference regression
regression_reference_label: Label for reference regression
force_equal_aspect: Force equal aspect on plot box
add_legend: add legend to the plot
legend_location: Location of the legend in the plot
"""
Expand Down Expand Up @@ -325,7 +337,14 @@ def scatter_plot(
)

if add_regression:
slope, intercept, _, _, _ = linregress(xdata, ydata)
if regression_method == "linregress":
slope, intercept, _, _, _ = linregress(xdata, ydata)
elif regression_method == "theilslopes":
slope, intercept, _, _ = theilslopes(ydata, xdata)
elif regression_method == "siegelslopes":
slope, intercept = siegelslopes(ydata, xdata)
else:
raise RuntimeError(f"Unknown linear regression method: {regression_method}")
regression_label = f"y = {slope:.4f}*x + {intercept:.4f}"
yregress = slope * xdata + intercept
ax.plot(
Expand All @@ -336,6 +355,15 @@ def scatter_plot(
label=regression_label,
lw=2,
)
if add_regression_reference:
reg_ref = regression_reference[0] * xdata + regression_reference[1]
ax.plot(
xdata,
reg_ref,
ls=regression_linestyle,
color=regression_reference_color,
label=regression_reference_label,
)

if model is not None:
ax.plot(
Expand Down Expand Up @@ -373,6 +401,9 @@ def scatter_plot(
ax_res.set_ylabel("Residuals")
ax_res.set_xlabel(xlabel)

if force_equal_aspect:
ax.set_aspect("equal", adjustable="box")

if title is not None:
ax.set_title(title)

Expand Down
Loading