diff --git a/src/astrohack/core/extract_pointing.py b/src/astrohack/core/extract_pointing.py index 2c2db7e8..6bdd7f84 100644 --- a/src/astrohack/core/extract_pointing.py +++ b/src/astrohack/core/extract_pointing.py @@ -179,7 +179,7 @@ def extract_pointing_chunk(pnt_params: dict, output_mds: AstrohackPointFile): tb.close() table_obj.close() - _evaluate_time_samping(direction_time, ant_name) + _evaluate_time_sampling(direction_time, ant_name) pnt_xds = xr.Dataset() coords = {"time": direction_time} @@ -379,17 +379,22 @@ def _extract_scan_time_dict_jit(time, scan_ids, state_ids, ddi_ids, mapping_stat return scan_time_dict -def _evaluate_time_samping( - time_sampling, data_label, threshold=0.01, expected_interval=0.1 +def _evaluate_time_sampling( + time_sampling, data_label, threshold=0.01, expected_interval=None ): + intervals = np.diff(time_sampling) + unq_intervals, counts = np.unique(intervals, return_counts=True) + if expected_interval is None: + i_max_count = np.argmax(counts) + expected_interval = unq_intervals[i_max_count] + bin_sz = expected_interval / 4 time_bin_edge = np.arange(-bin_sz / 2, 2.5 * expected_interval, bin_sz) time_bin_axis = time_bin_edge[:-1] + bin_sz / 2 i_mid = int(np.argmin(np.abs(time_bin_axis - expected_interval))) - intervals = np.diff(time_sampling) hist, edges = np.histogram(intervals, bins=time_bin_edge) - n_total = np.sum(hist) + n_total = np.nansum(hist) outlier_fraction = 1 - hist[i_mid] / n_total if outlier_fraction > threshold: diff --git a/src/astrohack/core/image_comparison_tool.py b/src/astrohack/core/image_comparison_tool.py index d71107a7..03607b2b 100644 --- a/src/astrohack/core/image_comparison_tool.py +++ b/src/astrohack/core/image_comparison_tool.py @@ -127,7 +127,8 @@ def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0): """ self.filename = fits_filename self.telescope_name = telescope_name - self.rootname = ".".join(fits_filename.split(".")[:-1]) + "." + fits_real_filename = fits_filename.split("/")[-1] + self.rootname = ".".join(fits_real_filename.split(".")[:-1]) + "." self.header, self.data = read_fits(self.filename, header_as_dict=True) stokes_iaxis = get_stokes_axis_iaxis(self.header) @@ -152,12 +153,16 @@ def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0): self.y_axis, _, self.y_unit = get_axis_from_fits_header( self.header, 2, pixel_offset=False ) + offset_scale = 1.5 + x_offset = offset_scale * np.unique(np.diff(self.x_axis))[0] + y_offset = offset_scale * np.unique(np.diff(self.y_axis))[0] + self.x_axis = np.flip(self.x_axis + x_offset) + self.y_axis = np.flip(self.y_axis + y_offset) self.x_unit = "m" self.y_unit = "m" elif "Astrohack" in self.header["ORIGIN"]: self.x_axis, _, self.x_unit = get_axis_from_fits_header(self.header, 1) self.y_axis, _, self.y_unit = get_axis_from_fits_header(self.header, 2) - self.data = np.fliplr(self.data) else: raise NotImplementedError(f'Unrecognized origin:\n{self.header["origin"]}') self._create_base_mask() @@ -208,11 +213,13 @@ def resample(self, ref_image): x_mesh_dest, y_mesh_dest = np.meshgrid( ref_image.x_axis, ref_image.y_axis, indexing="ij" ) + raveled_data = self.data.ravel() + valid_data = np.isfinite(raveled_data) resamp = griddata( - (x_mesh_orig.ravel(), y_mesh_orig.ravel()), - self.data.ravel(), + (x_mesh_orig.ravel()[valid_data], y_mesh_orig.ravel()[valid_data]), + raveled_data[valid_data], (x_mesh_dest.ravel(), y_mesh_dest.ravel()), - method="linear", + method="nearest", ) size = ref_image.x_axis.shape[0], ref_image.y_axis.shape[0] self.x_axis = ref_image.x_axis @@ -592,7 +599,15 @@ def export_to_fits(self, destination): reorder_axis=False, ) - def scatter_plot(self, destination, ref_image, dpi=300, display=False): + def scatter_plot( + self, + destination, + ref_image, + dpi=300, + display=False, + max_radius=None, + min_radius=None, + ): """ Produce a scatter plot of self.data agains ref_image.data Args: @@ -600,6 +615,8 @@ def scatter_plot(self, destination, ref_image, dpi=300, display=False): ref_image: Reference FITSImage object dpi: png resolution on disk display: Show interactive view of plot + max_radius: Maximum radius for scatter plot comparison as the outer panels can be crappy. + min_radius: Minimum radius for scatter plot comparison as the innermost panels can be crappy. Returns: None @@ -610,10 +627,23 @@ def scatter_plot(self, destination, ref_image, dpi=300, display=False): fig, ax = plt.subplots(1, 1, figsize=[10, 8]) + x_mesh_orig, y_mesh_orig = np.meshgrid(self.x_axis, self.y_axis, indexing="ij") + radius = np.sqrt(x_mesh_orig**2 + y_mesh_orig**2) + + telescope = get_proper_telescope(self.telescope_name) + if min_radius is None: + min_radius = telescope.inner_radial_limit + if max_radius is None: + max_radius = telescope.outer_radial_limit - 1.0 scatter_mask = np.isfinite(ref_image.data) scatter_mask = np.where(np.isfinite(self.data), scatter_mask, False) + scatter_mask = np.where(radius < max_radius, scatter_mask, False) + scatter_mask = np.where(radius > min_radius, scatter_mask, False) + ydata = self.data[scatter_mask] xdata = ref_image.data[scatter_mask] + pl_max = np.max((np.max(xdata), np.max(ydata))) + pl_min = np.min((np.min(xdata), np.min(ydata))) scatter_plot( ax, @@ -622,6 +652,12 @@ def scatter_plot(self, destination, ref_image, dpi=300, display=False): ydata, f"{self.filename} [{self.unit}]", add_regression=True, + regression_method="siegelslopes", + add_regression_reference=True, + regression_reference_label="Perfect agreement", + xlim=(pl_min, pl_max), + ylim=[pl_min, pl_max], + force_equal_aspect=True, ) close_figure( fig, @@ -641,7 +677,6 @@ def image_comparison_chunk(compare_params): Returns: A DataTree containing the Image and its reference Image. """ - image = FITSImage.from_fits_file( compare_params["this_image"], compare_params["telescope_name"] ) @@ -698,8 +733,8 @@ def image_comparison_chunk(compare_params): if compare_params["plot_scatter"]: image.scatter_plot(destination, ref_image, dpi=dpi, display=display) - img_node = xr.DataTree(name=image.filename, dataset=image.export_as_xds()) - ref_node = xr.DataTree(name=ref_image.filename, dataset=ref_image.export_as_xds()) + img_node = xr.DataTree(name=image.rootname, dataset=image.export_as_xds()) + ref_node = xr.DataTree(name=ref_image.rootname, dataset=ref_image.export_as_xds()) tree_node = xr.DataTree( name=image.rootname[:-1], children={"Reference": ref_node, "Image": img_node} ) diff --git a/src/astrohack/extract_holog.py b/src/astrohack/extract_holog.py index f0a29314..337de77c 100644 --- a/src/astrohack/extract_holog.py +++ b/src/astrohack/extract_holog.py @@ -1,4 +1,5 @@ import pathlib +from copy import deepcopy import toolviper.utils.parameter import toolviper.utils.logger as logger @@ -187,6 +188,9 @@ def extract_holog( } """ + # This copy here ensures that the user space holog_obs_dict given to extract_holog is not modified during execution + if holog_obs_dict is not None: + holog_obs_dict = deepcopy(holog_obs_dict) # Doing this here allows it to get captured by locals() holog_name = get_default_file_name(ms_name, ".holog.zarr", holog_name) diff --git a/src/astrohack/image_comparison_tool.py b/src/astrohack/image_comparison_tool.py index 042ce93e..4ed1b84a 100644 --- a/src/astrohack/image_comparison_tool.py +++ b/src/astrohack/image_comparison_tool.py @@ -201,7 +201,7 @@ def rms_table_from_zarr_datatree( raise FileNotFoundError xdt = xr.open_datatree(input_params["zarr_data_tree"]) - if xdt.attrs["origin"] != "compare_fits_images": + if xdt.attrs["origin_info"]["creator_function"] != "compare_fits_images": logger.error("Data tree file was not created by astrohack.compare_fits_images") raise ValueError diff --git a/src/astrohack/utils/text.py b/src/astrohack/utils/text.py index 3044093d..446aaec0 100644 --- a/src/astrohack/utils/text.py +++ b/src/astrohack/utils/text.py @@ -132,6 +132,9 @@ def get_default_file_name( else: output_filename = user_filename + if output_filename[-1] == "/": + output_filename = output_filename[:-1] + logger.info(f"Creating output file name: {output_filename}") return output_filename diff --git a/src/astrohack/visualization/plot_tools.py b/src/astrohack/visualization/plot_tools.py index d8e6d582..60f3ffd4 100644 --- a/src/astrohack/visualization/plot_tools.py +++ b/src/astrohack/visualization/plot_tools.py @@ -1,6 +1,6 @@ import matplotlib.image import numpy as np -from scipy.stats import linregress +from scipy.stats import linregress, theilslopes, siegelslopes from matplotlib import pyplot as plt from matplotlib.patches import Rectangle @@ -251,6 +251,12 @@ def scatter_plot( add_regression=False, regression_linestyle="-", regression_color="black", + regression_method="linregress", + add_regression_reference=False, + regression_reference=(1.0, 0.0), + regression_reference_color="orange", + regression_reference_label="Regression refrence", + force_equal_aspect=False, add_legend=True, legend_location="best", ): @@ -287,6 +293,12 @@ def scatter_plot( add_regression: Add a linear regression between X and y data regression_linestyle: Line style for the regression plot regression_color: Color for the regression plot + regression_method: Which scipy function to use for the linear regression: linregress, theilslopes or siegelslopes + add_regression_reference: Add reference for the expected regression result + regression_reference: 2 value array/tuple/list with a slope and intercept for reference + regression_reference_color: Color for reference regression + regression_reference_label: Label for reference regression + force_equal_aspect: Force equal aspect on plot box add_legend: add legend to the plot legend_location: Location of the legend in the plot """ @@ -325,7 +337,14 @@ def scatter_plot( ) if add_regression: - slope, intercept, _, _, _ = linregress(xdata, ydata) + if regression_method == "linregress": + slope, intercept, _, _, _ = linregress(xdata, ydata) + elif regression_method == "theilslopes": + slope, intercept, _, _ = theilslopes(ydata, xdata) + elif regression_method == "siegelslopes": + slope, intercept = siegelslopes(ydata, xdata) + else: + raise RuntimeError(f"Unknown linear regression method: {regression_method}") regression_label = f"y = {slope:.4f}*x + {intercept:.4f}" yregress = slope * xdata + intercept ax.plot( @@ -336,6 +355,15 @@ def scatter_plot( label=regression_label, lw=2, ) + if add_regression_reference: + reg_ref = regression_reference[0] * xdata + regression_reference[1] + ax.plot( + xdata, + reg_ref, + ls=regression_linestyle, + color=regression_reference_color, + label=regression_reference_label, + ) if model is not None: ax.plot( @@ -373,6 +401,9 @@ def scatter_plot( ax_res.set_ylabel("Residuals") ax_res.set_xlabel(xlabel) + if force_equal_aspect: + ax.set_aspect("equal", adjustable="box") + if title is not None: ax.set_title(title)