diff --git a/docs/api.rst b/docs/api.rst index 42c391e6..2b2800d1 100755 --- a/docs/api.rst +++ b/docs/api.rst @@ -16,3 +16,4 @@ Under development. _api/autoapi/astrohack/extract_locit/index _api/autoapi/astrohack/locit/index _api/autoapi/astrohack/cassegrain_ray_tracing/index + _api/autoapi/astrohack/image_comparison_tool/index diff --git a/src/astrohack/__init__.py b/src/astrohack/__init__.py index ed095d08..979771c6 100644 --- a/src/astrohack/__init__.py +++ b/src/astrohack/__init__.py @@ -14,6 +14,7 @@ from .locit import * from .extract_locit import * from .cassegrain_ray_tracing import * +from .image_comparison_tool import * # This installs a slick, informational tracebacks logger from rich.traceback import install diff --git a/src/astrohack/antenna/antenna_surface.py b/src/astrohack/antenna/antenna_surface.py index 081c02a4..5f4e85c8 100644 --- a/src/astrohack/antenna/antenna_surface.py +++ b/src/astrohack/antenna/antenna_surface.py @@ -14,7 +14,7 @@ from astrohack.visualization.plot_tools import well_positioned_colorbar, create_figure_and_axes, close_figure, \ get_proper_color_map -from astrohack.utils.fits import write_fits, resolution_to_fits_header, axis_to_fits_header +from astrohack.utils.fits import write_fits, put_resolution_in_fits_header, put_axis_in_fits_header lnbr = "\n" SUPPORTED_POL_STATES = ['I', 'RR', 'LL', 'XX', 'YY'] @@ -172,7 +172,7 @@ def _read_xds(self, inputxds): def _define_amp_clip(self, clip_type, clip_level): self.amplitude_noise = np.where(self.base_mask, np.nan, self.amplitude) if clip_type is None or clip_type == 'none': - clip = -np.inf + clip = np.nanmin(self.amplitude) elif clip_type == 'relative': clip = clip_level * np.nanmax(self.amplitude) elif clip_type == 'absolute': @@ -303,7 +303,7 @@ def _create_aperture_mask(self, clip_type, clip_level, exclude_shadows): arm_angle = 0.0 self.base_mask, self.rad, self.phi = create_aperture_mask(self.u_axis, self.v_axis, self.telescope.inlim, - self.telescope.diam/2, + self.telescope.oulim, arm_width=arm_width, arm_angle=arm_angle, return_polar_meshes=True) @@ -865,9 +865,9 @@ def export_to_fits(self, basename): 'WAVELENG': self.wavelength, 'FREQUENC': clight / self.wavelength, } - head = axis_to_fits_header(head, self.u_axis, 1, 'X----LIN', 'm') - head = axis_to_fits_header(head, self.v_axis, 2, 'Y----LIN', 'm') - head = resolution_to_fits_header(head, self.resolution) + head = put_axis_in_fits_header(head, self.u_axis, 1, 'X----LIN', 'm') + head = put_axis_in_fits_header(head, self.v_axis, 2, 'Y----LIN', 'm') + head = put_resolution_in_fits_header(head, self.resolution) write_fits(head, 'Amplitude', self.amplitude, add_prefix(basename, 'amplitude') + '.fits', self.amp_unit, 'panel') diff --git a/src/astrohack/antenna/telescope.py b/src/astrohack/antenna/telescope.py index 9e2c2dff..79e526dc 100644 --- a/src/astrohack/antenna/telescope.py +++ b/src/astrohack/antenna/telescope.py @@ -110,9 +110,19 @@ def write(self, filename): Args: filename: Name of the output file """ - ledict = vars(self) + obj_dict = vars(self) xds = xr.Dataset() - xds.attrs = ledict + xds.attrs = obj_dict + xds.to_zarr(filename, mode="w", compute=True, consolidated=True) + return + + def _save_to_dist(self): + obj_dict = vars(self) + filename = f'{self.filepath}/{self.filename}' + obj_dict.pop('filepath', None) + obj_dict.pop('filename', None) + xds = xr.Dataset() + xds.attrs = obj_dict xds.to_zarr(filename, mode="w", compute=True, consolidated=True) return @@ -139,7 +149,7 @@ def print(self): def __repr__(self): outstr = '' - ledict = vars(self) - for key, item in ledict.items(): + obj_dict = vars(self) + for key, item in obj_dict.items(): outstr += f"{key:20s} = {str(item)}\n" return outstr diff --git a/src/astrohack/config/image_comparison_tool.param.json b/src/astrohack/config/image_comparison_tool.param.json new file mode 100644 index 00000000..49b7e5b4 --- /dev/null +++ b/src/astrohack/config/image_comparison_tool.param.json @@ -0,0 +1,85 @@ +{ + "compare_fits_images":{ + "image":{ + "nullable": false, + "required": true, + "struct_type": ["str"], + "type": ["string", "list"] + }, + "reference_image":{ + "nullable": false, + "required": true, + "struct_type": ["str"], + "type": ["string", "list"] + }, + "telescope_name":{ + "nullable": false, + "required": true, + "type": ["string"] + }, + "destination":{ + "nullable": false, + "required": true, + "type": ["string"] + }, + "comparison":{ + "allowed": ["direct", "scaled"], + "nullable": false, + "required": false, + "type": ["string"] + }, + "zarr_container_name":{ + "nullable": true, + "required": false, + "type": ["string"] + }, + "plot_data":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "plot_percentuals":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "plot_divided_image":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "plot_scatter":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "export_to_fits":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "colormap":{ + "nullable": false, + "required": false, + "type": ["string"], + "check allowed with": "colormaps" + }, + "dpi":{ + "nullable": false, + "required": false, + "type": ["int"], + "min": 1, + "max": 1200 + }, + "display":{ + "nullable": false, + "required": false, + "type": ["boolean"] + }, + "parallel":{ + "nullable": false, + "required": false, + "type": ["boolean"] + } + } +} diff --git a/src/astrohack/core/image_comparison_tool.py b/src/astrohack/core/image_comparison_tool.py new file mode 100644 index 00000000..0fbe73c1 --- /dev/null +++ b/src/astrohack/core/image_comparison_tool.py @@ -0,0 +1,507 @@ +import numpy as np +from scipy.interpolate import griddata +from matplotlib import pyplot as plt +import xarray as xr +import pathlib + +from astrohack.antenna.telescope import Telescope +from astrohack.utils.text import statistics_to_text +from astrohack.utils.algorithms import create_aperture_mask, data_statistics, are_axes_equal +from astrohack.visualization.plot_tools import well_positioned_colorbar, compute_extent +from astrohack.visualization.plot_tools import close_figure, get_proper_color_map, scatter_plot +from astrohack.utils.fits import read_fits, get_axis_from_fits_header, get_stokes_axis_iaxis, put_axis_in_fits_header, \ + write_fits + + +def test_image(fits_image): + if isinstance(fits_image, FITSImage): + pass + else: + raise TypeError('Reference image is not a FITSImage object') + + +class FITSImage: + + def __init__(self): + """ + Blank slate initialization of the FITSImage object + """ + # Attributes: + self.filename = None + self.telescope_name = None + self.rootname = None + self.factor = 1.0 + self.reference_name = None + self.resampled = False + + # Metadata + self.header = None + self.unit = None + self.x_axis = None + self.y_axis = None + self.original_x_axis = None + self.original_y_axis = None + self.x_unit = None + self.y_unit = None + + # Data variables + self.original_data = None + self.data = None + self.residuals = None + self.residuals_percent = None + self.divided_image = None + + @classmethod + def from_xds(cls, xds): + """ + Initialize a FITSImage object using as a base a Xarray dataset + Args: + xds: Xarray dataset + + Returns: + FITSImage object initialized from a xds + """ + return_obj = cls() + return_obj._init_as_xds(xds) + return return_obj + + @classmethod + def from_fits_file(cls, fits_filename, telescope_name): + """ + Initialize a FITSImage object using as a base a FITS file. + Args: + fits_filename: FITS file on disk + telescope_name: Name of the telescope used + + Returns: + FITSImage object initialized from a FITS file + """ + return_obj = cls() + return_obj._init_as_fits(fits_filename, telescope_name) + return return_obj + + @classmethod + def from_zarr(cls, zarr_filename): + """ + Initialize a FITSImage object using as a base a Xarray dataset store on disk in a zarr container + Args: + zarr_filename: Xarray dataset on disk as a zarr container + + Returns: + FITSImage object initialized from a xds + """ + return_obj = cls() + xds = xr.open_zarr(zarr_filename) + return_obj._init_as_xds(xds) + return return_obj + + def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0): + """ + Backend for FITSImage.from_fits_file + Args: + fits_filename: FITS file on disk + telescope_name: Name of the telescope used + istokes: Stokes axis element to be fetched, should always be zero (singleton stokes axis or fetching I) + ichan: Channel axis element to be fetched, should be zero for most cases, unless image has multiple channels + + Returns: + None + """ + self.filename = fits_filename + self.telescope_name = telescope_name + self.rootname = '.'.join(fits_filename.split('.')[:-1])+'.' + self.header, self.data = read_fits(self.filename, header_as_dict=True) + stokes_iaxis = get_stokes_axis_iaxis(self.header) + + self.unit = self.header['BUNIT'] + + if len(self.data.shape) == 4: + if stokes_iaxis == 4: + self.data = self.data[istokes, ichan, ...] + else: + self.data = self.data[ichan, istokes, ...] + elif len(self.data.shape) == 2: + pass # image is already as expected + else: + raise Exception(f'FITS image has an unsupported shape: {self.data.shape}') + + self.original_data = np.copy(self.data) + + if 'AIPS' in self.header['ORIGIN']: + self.x_axis, _, self.x_unit = get_axis_from_fits_header(self.header, 1, pixel_offset=False) + self.y_axis, _, self.y_unit = get_axis_from_fits_header(self.header, 2, pixel_offset=False) + self.x_unit = 'm' + self.y_unit = 'm' + elif 'Astrohack' in self.header['ORIGIN']: + self.x_axis, _, self.x_unit = get_axis_from_fits_header(self.header, 1) + self.y_axis, _, self.y_unit = get_axis_from_fits_header(self.header, 2) + self.data = np.fliplr(self.data) + else: + raise Exception(f'Unrecognized origin:\n{self.header["origin"]}') + self._create_base_mask() + self.original_x_axis = np.copy(self.x_axis) + self.original_y_axis = np.copy(self.y_axis) + + def _init_as_xds(self, xds): + """ + Backend for FITSImage.from_xds + Args: + xds: Xarray DataSet + Returns: + None + """ + for key in xds.attrs: + setattr(self, key, xds.attrs[key]) + + self.x_axis = xds.x.values + self.y_axis = xds.y.values + self.original_x_axis = xds.original_x.values + self.original_y_axis = xds.original_y.values + + for key, value in xds.items(): + setattr(self, str(key), xds[key].values) + + def _create_base_mask(self): + """ + Create a base mask based on telescope parameters such as arm shadows. + Returns: + None + """ + telescope_obj = Telescope(self.telescope_name) + self.base_mask = create_aperture_mask(self.x_axis, self.y_axis, telescope_obj.inlim, telescope_obj.oulim, + arm_width=telescope_obj.arm_shadow_width, + arm_angle=telescope_obj.arm_shadow_rotation) + + def resample(self, ref_image): + """ + Resamples the data on this object onto the grid in ref_image + Args: + ref_image: Reference FITSImage object + + Returns: + None + """ + test_image(ref_image) + x_mesh_orig, y_mesh_orig = np.meshgrid(self.x_axis, self.y_axis, indexing='ij') + x_mesh_dest, y_mesh_dest = np.meshgrid(ref_image.x_axis, ref_image.y_axis, indexing='ij') + resamp = griddata((x_mesh_orig.ravel(), y_mesh_orig.ravel()), self.data.ravel(), + (x_mesh_dest.ravel(), y_mesh_dest.ravel()), + method='linear') + size = ref_image.x_axis.shape[0], ref_image.y_axis.shape[0] + self.x_axis = ref_image.x_axis + self.y_axis = ref_image.y_axis + self.data = resamp.reshape(size) + self._create_base_mask() + self.resampled = True + + def compare_difference(self, ref_image): + """ + Does the difference comparison between self and ref_image. + Args: + ref_image: Reference FITSImage object + + Returns: + None + """ + test_image(ref_image) + if not self.image_has_same_sampling(ref_image): + self.resample(ref_image) + + self.residuals = ref_image.data - (self.data * self.factor) + self.residuals_percent = 100 * self.residuals/ref_image.data + self.reference_name = ref_image.filename + + def compare_scaled_difference(self, ref_image, rejection=10): + """ + Does the scaled difference comparison between self and ref_image. + Args: + ref_image: Reference FITSImage object + rejection: rejection level for scaling factor + + Returns: + None + """ + test_image(ref_image) + if not self.image_has_same_sampling(ref_image): + self.resample(ref_image) + simple_division = ref_image.data / self.data + rough_factor = np.nanmean(simple_division[self.base_mask]) + self.divided_image = np.where(np.abs(simple_division) > rejection*rough_factor, np.nan, simple_division) + self.factor = np.nanmedian(self.divided_image) + self.compare_difference(ref_image) + + def image_has_same_sampling(self, ref_image): + """ + Tests if self has the same X and Y sampling as ref_image + Args: + ref_image: Reference FITSImage object + + Returns: + True or False + """ + test_image(ref_image) + return are_axes_equal(self.x_axis, ref_image.x_axis) and are_axes_equal(self.y_axis, ref_image.y_axis) + + def _mask_array(self, image_array): + """ + Applies base mask to image_array + Args: + image_array: Data array to be masked + + Returns: + Masked array + """ + return np.where(self.base_mask, image_array, np.nan) + + def plot_images(self, destination, plot_data=False, plot_percentuals=False, + plot_divided_image=False, colormap='viridis', dpi=300, display=False): + """ + Plot image contents of the FITSImage object, always plots the residuals when called + Args: + destination: Location onto which save plot files + plot_data: Also plot data array? + plot_percentuals: Also plot percentual residuals array? + plot_divided_image: Also plot divided image? + colormap: Colormap name for image plots + dpi: png resolution on disk + display: Show interactive view of plots + + Returns: + None + """ + + extent = compute_extent(self.x_axis, self.y_axis, 0.0) + cmap = get_proper_color_map(colormap) + base_name = f'{destination}/{self.rootname}' + + if self.residuals is None: + raise Exception("Cannot plot results as they don't exist yet.") + self._plot_map(self._mask_array(self.residuals), f'Residuals, ref={self.reference_name}', + f'Residuals [{self.unit}]', f'{base_name}residuals.png', cmap, extent, + 'symmetrical', dpi, display, add_statistics=True) + + if plot_data: + self._plot_map(self._mask_array(self.data), 'Original Data', f'Data [{self.unit}]', + f'{base_name}data.png', cmap, extent, [None, None], dpi, display, + add_statistics=False) + + if plot_percentuals: + if self.residuals is None: + raise Exception("Cannot plot results as they don't exist yet.") + self._plot_map(self._mask_array(self.residuals_percent), f'Residuals in %, ref={self.reference_name}', + f'Residuals [%]', f'{base_name}residuals_percent.png', cmap, extent, + 'symmetrical', dpi, display, add_statistics=True) + + if plot_divided_image: + if self.divided_image is None: + raise Exception("Cannot plot a divided image that does not exist.") + self._plot_map(self._mask_array(self.divided_image), + f'Divided image, ref={self.reference_name}, scaling={self.factor:.4f}', + f'Division [ ]', f'{base_name}divided.png', cmap, extent, [None, None], + dpi, display, add_statistics=True) + + def _plot_map(self, data, title, zlabel, filename, cmap, extent, zscale, dpi, display, add_statistics=False): + """ + Backend for plot_images + Args: + data: Data array to be plotted + title: Title to appear on plot + zlabel: Label for the colorbar + filename: name for the png file on disk + cmap: Colormap object for plots + extent: extents of the X and Y axes + zscale: Constraints on the Z axes. + dpi: png resolution on disk + display: Show interactive view of plots + add_statistics: Add simple statistics to plot's subtitle + + Returns: + None + """ + fig, ax = plt.subplots(1, 1, figsize=[10, 8]) + if zscale == 'symmetrical': + scale = max(np.abs(np.nanmin(data)), np.abs(np.nanmax(data))) + vmin, vmax = -scale, scale + else: + vmin, vmax = zscale + if vmin == 'None' or vmin is None: + vmin = np.nanmin(data) + if vmax == 'None' or vmax is None: + vmax = np.nanmax(data) + + im = ax.imshow(data, cmap=cmap, interpolation="nearest", extent=extent, + vmin=vmin, vmax=vmax,) + well_positioned_colorbar(ax, fig, im, zlabel, location='right', size='5%', pad=0.05) + ax.set_xlabel(f"X axis [{self.x_unit}]") + ax.set_ylabel(f"Y axis [{self.y_unit}]") + if add_statistics: + data_stats = data_statistics(data) + ax.set_title(statistics_to_text(data_stats)) + close_figure(fig, title, filename, dpi, display) + + def export_as_xds(self): + """ + Create a Xarray DataSet from the FITSImage object + Returns: + Xarray DataSet + """ + xds = xr.Dataset() + obj_dict = vars(self) + + coords = {'x': self.x_axis, 'y': self.y_axis, + 'original_x': self.original_x_axis, 'original_y': self.original_y_axis} + for key, value in obj_dict.items(): + failed = False + if isinstance(value, np.ndarray): + if len(value.shape) == 2: + if 'original' in key: + xds[key] = xr.DataArray(value, dims=['original_x', 'original_y']) + else: + xds[key] = xr.DataArray(value, dims=['x', 'y']) + elif len(value.shape) == 1: + pass # Axes + else: + failed = True + else: + xds.attrs[key] = value + + if failed: + raise Exception(f"Don't know what to do with: {key}") + + xds = xds.assign_coords(coords) + return xds + + def to_zarr(self, zarr_filename): + """ + Saves a xds representation of self on disk using the zarr format. + Args: + zarr_filename: Name for the zarr container on disk + + Returns: + None + """ + xds = self.export_as_xds() + xds.to_zarr(zarr_filename, mode="w", compute=True, consolidated=True) + + def __repr__(self): + """ + Print method + Returns: + A String summary of the current status of self. + """ + obj_dict = vars(self) + outstr = '' + for key, value in obj_dict.items(): + if isinstance(value, np.ndarray): + outstr += f'{key:17s} -> {value.shape}' + elif isinstance(value, dict): + outstr += f'{key:17s} -> dict()' + else: + outstr += f'{key:17s} = {value}' + outstr += '\n' + return outstr + + def export_to_fits(self, destination): + """ + Export internal images to FITS files. + Args: + destination: location to store FITS files + + Returns: + None + """ + pathlib.Path(destination).mkdir(exist_ok=True) + ext_fits = '.fits' + out_header = self.header.copy() + + put_axis_in_fits_header(out_header, self.x_axis, 1, '', self.x_unit) + put_axis_in_fits_header(out_header, self.y_axis, 2, '', self.y_unit) + + obj_dict = vars(self) + for key, value in obj_dict.items(): + if isinstance(value, np.ndarray): + if len(value.shape) == 2: + if 'original' in key: + pass + else: + if key == 'base_mask' or key == 'divided_image': + unit = '' + + elif key == 'residuals_percent': + unit = '%' + else: + unit = self.unit + filename = f'{destination}/{self.rootname}{key}{ext_fits}' + write_fits(out_header, key, np.fliplr(value.astype(float)), filename, unit, reorder_axis=False) + + def scatter_plot(self, destination, ref_image, dpi=300, display=False): + """ + Produce a scatter plot of self.data agains ref_image.data + Args: + destination: Location to store scatter plot + ref_image: Reference FITSImage object + dpi: png resolution on disk + display: Show interactive view of plot + + Returns: + None + """ + test_image(ref_image) + if not self.image_has_same_sampling(ref_image): + self.resample(ref_image) + + fig, ax = plt.subplots(1, 1, figsize=[10, 8]) + + scatter_mask = np.isfinite(ref_image.data) + scatter_mask = np.where(np.isfinite(self.data), scatter_mask, False) + ydata = self.data[scatter_mask] + xdata = ref_image.data[scatter_mask] + + scatter_plot(ax, xdata, f'Reference image {ref_image.filename} [{ref_image.unit}]', + ydata, f'{self.filename} [{self.unit}]', add_regression=True) + close_figure(fig, 'Scatter plot against reference image', f'{destination}/{self.rootname}scatter.png', + dpi, display) + + +def image_comparison_chunk(compare_params): + """ + Chunk function for parallel execution of the image comparison tool. + Args: + compare_params: Parameter dictionary for workflow control. + + Returns: + A DataTree containing the Image and its reference Image. + """ + + image = FITSImage.from_fits_file(compare_params['this_image'], compare_params['telescope_name']) + ref_image = FITSImage.from_fits_file(compare_params['this_reference_image'], compare_params['telescope_name']) + plot_data = compare_params['plot_data'] + plot_percentuals = compare_params['plot_percentuals'] + plot_divided = compare_params['plot_divided_image'] + destination = compare_params['destination'] + colormap = compare_params['colormap'] + dpi = compare_params['dpi'] + display = compare_params['display'] + + if compare_params['comparison'] == 'direct': + image.compare_difference(ref_image) + image.plot_images(destination, plot_data, plot_percentuals, False, colormap=colormap, dpi=dpi, + display=display) + elif compare_params['comparison'] == 'scaled': + image.compare_scaled_difference(ref_image) + image.plot_images(destination, plot_data, plot_percentuals, plot_divided, colormap=colormap, dpi=dpi, + display=display) + else: + raise Exception(f'Unknown comparison type {compare_params["comparison"]}') + + if compare_params['export_to_fits']: + image.export_to_fits(destination) + + if compare_params['plot_scatter']: + image.scatter_plot(destination, ref_image, dpi=dpi, display=display) + + img_node = xr.DataTree(name=image.filename, dataset=image.export_as_xds()) + ref_node = xr.DataTree(name=ref_image.filename, dataset=ref_image.export_as_xds()) + tree_node = xr.DataTree(name=image.rootname[:-1], children={'Reference': ref_node, 'Image': img_node}) + + return tree_node diff --git a/src/astrohack/data/telescopes/vla.zarr/.zattrs b/src/astrohack/data/telescopes/vla.zarr/.zattrs index f063e385..f5dae34c 100644 --- a/src/astrohack/data/telescopes/vla.zarr/.zattrs +++ b/src/astrohack/data/telescopes/vla.zarr/.zattrs @@ -29,8 +29,29 @@ "ea27", "ea28" ], - "arm_shadow_rotation": 0.0, - "arm_shadow_width": 1.5, + "arm_shadow_rotation": 0, + "arm_shadow_width": [ + [ + 1.983, + 7.391, + 0.8 + ], + [ + 7.391, + 9.144, + 1.2 + ], + [ + 9.144, + 10.87, + 1.4 + ], + [ + 10.87, + 12.5, + 1.8 + ] + ], "array_center": { "m0": { "unit": "rad", diff --git a/src/astrohack/data/telescopes/vla.zarr/.zmetadata b/src/astrohack/data/telescopes/vla.zarr/.zmetadata index dbc7a9ac..eae7956c 100644 --- a/src/astrohack/data/telescopes/vla.zarr/.zmetadata +++ b/src/astrohack/data/telescopes/vla.zarr/.zmetadata @@ -31,8 +31,29 @@ "ea27", "ea28" ], - "arm_shadow_rotation": 0.0, - "arm_shadow_width": 1.5, + "arm_shadow_rotation": 0, + "arm_shadow_width": [ + [ + 1.983, + 7.391, + 0.8 + ], + [ + 7.391, + 9.144, + 1.2 + ], + [ + 9.144, + 10.87, + 1.4 + ], + [ + 10.87, + 12.5, + 1.8 + ] + ], "array_center": { "m0": { "unit": "rad", diff --git a/src/astrohack/image_comparison_tool.py b/src/astrohack/image_comparison_tool.py new file mode 100644 index 00000000..ab8e3351 --- /dev/null +++ b/src/astrohack/image_comparison_tool.py @@ -0,0 +1,142 @@ +from typing import Union, List +import xarray as xr +import pathlib + +import toolviper.utils.logger as logger +import toolviper + +from astrohack.core.image_comparison_tool import image_comparison_chunk +from astrohack.utils.graph import compute_graph_from_lists +from astrohack.utils.validation import custom_plots_checker + + +@toolviper.utils.parameter.validate( + custom_checker=custom_plots_checker +) +def compare_fits_images( + image: Union[str, List[str]], + reference_image: Union[str, List[str]], + telescope_name: str, + destination: str, + comparison: str = 'direct', + zarr_container_name: str = None, + plot_data: bool = False, + plot_percentuals: bool = False, + plot_divided_image: bool = False, + plot_scatter: bool = True, + export_to_fits: bool = False, + colormap: str = 'viridis', + dpi: int = 300, + display: bool = False, + parallel: bool = False +): + """ + Compares a set of images to a set of reference images. + + :param image: FITS image or list of FITS images to be compared. + :type image: list or str + + :param reference_image: FITS image or list of FITS images that serve as references. + :type reference_image: list or str + + :param telescope_name: Name of the telescope used. Used for masking. + :type telescope_name: str + + :param destination: Name of directory onto which save plots + :type destination: str + + :param comparison: Type of comparison to be made between images, "direct" or "scaled", default is "direct". + :type comparison: str, optional + + :param zarr_container_name: Name of the Zarr container to contain the created datatree, default is None, i.e. \ + DataTree is not saved to disk. + :type zarr_container_name: str, optional + + :param plot_data: Plot the data array used in the comparison, default is False. + :type plot_data: bool, optional + + :param plot_percentuals: Plot the residuals in percent of reference image as well, default is False. + :type plot_percentuals: bool, optional + + :param plot_divided_image: Plot the divided image between Image and its reference, default is False. + :type plot_divided_image: bool, optional + + :param plot_scatter: Make a scatter plot of the Image against its reference image, default is True. + :type plot_scatter: bool, optional + + :param export_to_fits: Export created images to FITS files inside destination, default is False. + :type export_to_fits: bool, optional + + :param colormap: Colormap to be used on image plots, default is "viridis". + :type colormap: str, optional + + :param dpi: dots per inch to be used in plots, default is 300. + :type dpi: int, optional + + :param display: Display plots inline or suppress, defaults to True + :type display: bool, optional + + :param parallel: If True will use an existing astrohack client to do comparison in parallel, default is False + :type parallel: bool, optional + + :return: DataTree object containing all the comparisons executed + :rtype: xr.DataTree + + .. _Description: + Compares pairs of FITS images pixel by pixel using a mask based on telescope parameters to exclude problematic \ + regions such as shadows caused by the secondary mirror or the arms supporting it. By default, 2 products are \ + produced, a plot of the residuals image, i.e. (Reference - Image) and a scatter plot of the Reference against the \ + Image. If necessary a resample of Image is conducted to allow for pixel by pixel comparison. + + .. rubric:: Comparison: + Two types of comparison between the images are available: + - *direct*: Where the residuals are simply computed as Reference - Image. + - *scaled*: Where the residuals are Reference - Factor * Image, with Factor = median(Reference/Image). + + .. rubric:: Plots: + A plot of the residuals of the comparison is always produced. + However, a few extra plots can be produced and their production is controlled by the *plot_* parameters, these are: + - *plot_data*: Activates plotting of the data used in the comparison, default is False as this is the data on \ + the FITS file. + - *plot_percentuals*: Activates the plotting of the residuals as a perdentage of the Reference Image, default \ + is False as this is just another view on the residuals. + - *plot_divided_image*: Activates the plotting of Reference/Image, default is False. This plot is only \ + available when using "scaled" comparison. + - *plot_scatter*: Activates the creation of a scatter plot of Reference vs Image, with a linear regression, \ + default is True. + + .. rubric:: Storage on disk: + By default, this function only produces plots, but this can be changed using two parameters: + - *zarr_container_name*: If this parameter is not None a Zarr container will be created on disk with the \ + contents of the produced DataTree. + - *export_to_fits*: If set to True will produce FITS files of the produced images and store them at \ + *destination*. + + .. rubric:: Return type: + This funtion returns a Xarray DataTree containing the Xarray DataSets that represent Image and Reference. The nodes \ + in this DataTree are labelled according to the filenames given as input for easier navigation. + """ + + if isinstance(image, str): + image = [image] + if isinstance(reference_image, str): + reference_image = [reference_image] + if len(image) != len(reference_image): + msg = 'List of reference images has a different size from the list of images' + logger.error(msg) + return + + param_dict = locals() + pathlib.Path(param_dict['destination']).mkdir(exist_ok=True) + + result_list = compute_graph_from_lists(param_dict, image_comparison_chunk, ['image', 'reference_image'], parallel) + + root = xr.DataTree(name='Root') + for item in result_list: + tree_node = item[0] + root = root.assign({tree_node.name: tree_node}) + + if zarr_container_name is not None: + root.to_zarr(zarr_container_name, mode='w', consolidated=True) + + return root diff --git a/src/astrohack/panel.py b/src/astrohack/panel.py index 88a3e1c1..8c8d1445 100644 --- a/src/astrohack/panel.py +++ b/src/astrohack/panel.py @@ -52,7 +52,7 @@ def panel( passing a dictionary, default is 3 (appropriate for sigma clipping) :type clip_level: float, dict, optional - :param exclude_shadows: Exclude regions with significant shadowing from analysis, e.g. secondary supporting arms, + :param exclude_shadows: Exclude regions with significant shadowing from analysis, e.g. secondary supporting arms, \ default is True. :type exclude_shadows: bool, optional diff --git a/src/astrohack/utils/algorithms.py b/src/astrohack/utils/algorithms.py index a6895e9a..957a36d1 100644 --- a/src/astrohack/utils/algorithms.py +++ b/src/astrohack/utils/algorithms.py @@ -538,6 +538,18 @@ def create_coordinate_images(x_axis, y_axis, create_polar_coordinates=False): def create_aperture_mask(x_axis, y_axis, inner_rad, outer_rad, arm_width=None, arm_angle=0, return_polar_meshes=False): """ + Create a basic aperture mask with support for feed supporting arms shadows + Args: + x_axis: The X axis of the Aperture + y_axis: The Y axis of the Aperture + inner_rad: The innermost radius for valid data in aperture + outer_rad: The outermost radius for valid data in aperture + arm_width: The width of the feed arm shadows, can be a list with limiting radii or a single value. + arm_angle: The angle between the arm shadows and the X axis + return_polar_meshes: Return the radial and polar meshes to avoid duplicate computations. + + Returns: + """ x_mesh, y_mesh, radius_mesh, polar_angle_mesh = \ create_coordinate_images(x_axis, y_axis, create_polar_coordinates=True) @@ -547,21 +559,41 @@ def create_aperture_mask(x_axis, y_axis, inner_rad, outer_rad, arm_width=None, a if arm_width is None: pass + elif isinstance(arm_width, (float, int)): + mask = _arm_shadow_masking(mask, x_mesh, y_mesh, radius_mesh, inner_rad, outer_rad, arm_width, arm_angle) + elif isinstance(arm_width, list): + for section in arm_width: + minradius, maxradius, width = section + mask = _arm_shadow_masking(mask, x_mesh, y_mesh, radius_mesh, minradius, maxradius, width, arm_angle) + else: - if arm_angle % pi/2 == 0: - mask = np.where(np.abs(x_mesh) < arm_width/2., False, mask) - mask = np.where(np.abs(y_mesh) < arm_width/2., False, mask) - else: - # first shadow - coeff = np.tan(arm_angle % pi) - distance = np.abs((coeff*x_mesh-y_mesh)/np.sqrt(coeff**2+1)) - mask = np.where(distance < arm_width/2., False, mask) - # second shadow - coeff = np.tan(arm_angle % pi + pi/2) - distance = np.abs((coeff*x_mesh-y_mesh)/np.sqrt(coeff**2+1)) - mask = np.where(distance < arm_width/2., False, mask) + raise Exception(f"Don't know how to handle an arm width of class {type(arm_width)}") if return_polar_meshes: return mask, radius_mesh, polar_angle_mesh else: return mask + + +def _arm_shadow_masking(inmask, x_mesh, y_mesh, radius_mesh, minradius, maxradius, width, angle): + radial_mask = np.where(radius_mesh < minradius, False, inmask) + radial_mask = np.where(radius_mesh >= maxradius, False, radial_mask) + if angle % pi/2 == 0: + oumask = np.where(np.bitwise_and(np.abs(x_mesh) < width/2., radial_mask), False, inmask) + oumask = np.where(np.bitwise_and(np.abs(y_mesh) < width/2., radial_mask), False, oumask) + else: + # first shadow + coeff = np.tan(angle % pi) + distance = np.abs((coeff*x_mesh-y_mesh)/np.sqrt(coeff**2+1)) + oumask = np.where(np.bitwise_and(distance < width/2., radial_mask), False, inmask) + # second shadow + coeff = np.tan(angle % pi + pi/2) + distance = np.abs((coeff*x_mesh-y_mesh)/np.sqrt(coeff**2+1)) + oumask = np.where(np.bitwise_and(distance < width/2., radial_mask), False, oumask) + return oumask + + +def are_axes_equal(axis_a, axis_b): + if axis_a.shape[0] != axis_b.shape[0]: + return False + return np.all(axis_a == axis_b) diff --git a/src/astrohack/utils/fits.py b/src/astrohack/utils/fits.py index d271f4a8..f0d8d5b3 100644 --- a/src/astrohack/utils/fits.py +++ b/src/astrohack/utils/fits.py @@ -10,7 +10,41 @@ from astrohack.utils.text import add_prefix -def read_fits(filename): +def get_stokes_axis_iaxis(header): + """ + Get which of the axis in the header is the stokes axis + Args: + header: FITS header + + Returns: + None if no stokes axis is found, iaxis if stokes axis is found + """ + naxis = header['NAXIS'] + for iaxis in range(naxis): + axis_type = safe_keyword_fetch(header, f'CTYPE{iaxis+1}') + if 'STOKES' in axis_type: + return iaxis + 1 + return None + + + +def safe_keyword_fetch(header_dict, keyword): + """ + Tries to fetch a keyword from a FITS header / dictionary + Args: + header_dict: FITS header / Dictionary + keyword: The intended keyword to fetch + + Returns: + Keyword value if prensent, None if not present. + """ + try: + return header_dict[keyword] + except KeyError: + return None + + +def read_fits(filename, header_as_dict=True): """ Reads a square FITS file and do sanity checks on its dimensionality Args: @@ -21,7 +55,7 @@ def read_fits(filename): """ hdul = fits.open(filename) head = hdul[0].header - data = hdul[0].data[0, 0, :, :] + data = hdul[0].data hdul.close() if head["NAXIS"] != 1: if head["NAXIS"] < 1: @@ -32,10 +66,42 @@ def read_fits(filename): raise Exception(filename + " is not bi-dimensional") if head["NAXIS1"] != head["NAXIS2"]: raise Exception(filename + " does not have the same amount of pixels in the x and y axes") - return head, data + + if header_as_dict: + header_dict = {} + for key, value in head.items(): + header_dict[key] = value + return header_dict, data + else: + return head, data -def write_fits(header, imagetype, data, filename, unit, origin): +def get_axis_from_fits_header(header, iaxis, pixel_offset=True): + """ + Pull axis information from FITS file and store it in a numpy array, ignores rotation in axes. + Args: + header: FITS header + iaxis: Which axis is to be fetched from the header. + pixel_offset: apply one pixel offset + + Returns: + numpy array representation of axis, axis type and axis unit + """ + n_elem = header[f'NAXIS{iaxis}'] + ref = header[f'CRPIX{iaxis}'] + inc = header[f'CDELT{iaxis}'] + if pixel_offset: + val = header[f'CRVAL{iaxis}'] + inc # This makes this routine symmetrical to the put routine. + else: + val = header[f'CRVAL{iaxis}'] + axis = np.arange(n_elem) + axis = val + (ref-axis)*inc + axis_unit = safe_keyword_fetch(header, f'CUNIT{iaxis}') + axis_type = safe_keyword_fetch(header, f'CTYPE{iaxis}') + return axis, axis_type, axis_unit + + +def write_fits(header, imagetype, data, filename, unit, origin=None, reorder_axis=True): """ Write a dictionary and a dataset to a FITS file Args: @@ -45,6 +111,7 @@ def write_fits(header, imagetype, data, filename, unit, origin): filename: The name of the output file unit: to be set to bunit origin: Which astrohack mds has created the FITS being written + reorder_axis: Reorder data axes so that they are compatible with regular FITS ordering """ header['BUNIT'] = unit @@ -52,10 +119,20 @@ def write_fits(header, imagetype, data, filename, unit, origin): header['ORIGIN'] = f'Astrohack v{astrohack.__version__}: {origin}' header['DATE'] = datetime.datetime.now().strftime('%b %d %Y, %H:%M:%S') - hdu = fits.PrimaryHDU(_reorder_axes_for_fits(data)) + if origin is None: + header['ORIGIN'] = f'Astrohack v{astrohack.__version__}' + outfile = filename + else: + header['ORIGIN'] = f'Astrohack v{astrohack.__version__}: {origin}' + outfile = add_prefix(filename, origin) + + if reorder_axis: + hdu = fits.PrimaryHDU(_reorder_axes_for_fits(data)) + else: + hdu = fits.PrimaryHDU(data) for key in header.keys(): hdu.header.set(key, header[key]) - hdu.writeto(add_prefix(filename, origin), overwrite=True) + hdu.writeto(outfile, overwrite=True) return @@ -73,7 +150,7 @@ def _reorder_axes_for_fits(data: np.ndarray): return np.flipud(data) -def resolution_to_fits_header(header, resolution): +def put_resolution_in_fits_header(header, resolution): """ Adds resolution information to standard header keywords: BMAJ, BMIN and BPA Args: @@ -95,7 +172,7 @@ def resolution_to_fits_header(header, resolution): return header -def axis_to_fits_header(header: dict, axis, iaxis, axistype, unit, iswcs=True): +def put_axis_in_fits_header(header: dict, axis, iaxis, axistype, unit, iswcs=True): """ Process an axis to create a FITS compatible linear axis description Args: @@ -103,6 +180,8 @@ def axis_to_fits_header(header: dict, axis, iaxis, axistype, unit, iswcs=True): axis: The axis to be described in the header iaxis: The position of the axis in the data axistype: Axis type to be displayed in the fits header + unit: Axis unit + iswcs: Is the axis a part of World Coordinate System for the image? Returns: The augmented header @@ -142,7 +221,7 @@ def axis_to_fits_header(header: dict, axis, iaxis, axistype, unit, iswcs=True): return outheader -def stokes_axis_to_fits_header(header, iaxis): +def put_stokes_axis_in_fits_header(header, iaxis): """ Inserts a dedicated stokes axis in the header at iaxis Args: diff --git a/src/astrohack/utils/graph.py b/src/astrohack/utils/graph.py index f989a19b..4f88f500 100644 --- a/src/astrohack/utils/graph.py +++ b/src/astrohack/utils/graph.py @@ -89,3 +89,36 @@ def compute_graph(looping_dict, chunk_function, param_dict, key_order, parallel= if parallel: dask.compute(delayed_list) return True + + +def compute_graph_from_lists(param_dict, chunk_function, looping_key_list, parallel=False): + """ + Creates and executes a graph based on entries in a parameter dictionary that are lists + Args: + param_dict: The parameter dictionary + chunk_function: The function for the operation chunk + looping_key_list: The keys that are lists in the parameter dictionaries over which to loop over + parallel: execute graph in parallel? + + Returns: + A list containing the returns of the calls to the chunk function. + """ + niter = len(param_dict[looping_key_list[0]]) + + delayed_list = [] + result_list = [] + for i_iter in range(niter): + this_param = param_dict.copy() + for key in looping_key_list: + this_param[f'this_{key}'] = param_dict[key][i_iter] + + if parallel: + delayed_list.append(dask.delayed(chunk_function)(dask.delayed(this_param))) + else: + delayed_list.append(0) + result_list.append(chunk_function(this_param)) + + if parallel: + result_list = dask.compute(delayed_list) + + return result_list diff --git a/src/astrohack/utils/text.py b/src/astrohack/utils/text.py index 8522386a..fecbeea9 100644 --- a/src/astrohack/utils/text.py +++ b/src/astrohack/utils/text.py @@ -574,7 +574,7 @@ def significant_figures_round(x, digits): def statistics_to_text(data_statistics, keys=None): if keys is None: - outstr = (f'min={data_statistics["min"]:.2e}, max={data_statistics["max"]:.2f}, ' + outstr = (f'min={data_statistics["min"]:.2f}, max={data_statistics["max"]:.2f}, ' f'mean={data_statistics["mean"]:.2f}, med={data_statistics["median"]:.2f}, ' f'rms={data_statistics["rms"]:.2f}') else: diff --git a/src/astrohack/visualization/fits.py b/src/astrohack/visualization/fits.py index 1db92166..c1dcd4b8 100644 --- a/src/astrohack/visualization/fits.py +++ b/src/astrohack/visualization/fits.py @@ -2,7 +2,7 @@ from toolviper.utils import logger as logger from astrohack.antenna import Telescope, AntennaSurface from astrohack.utils import clight, convert_unit, add_prefix -from astrohack.utils.fits import axis_to_fits_header, stokes_axis_to_fits_header, write_fits, resolution_to_fits_header +from astrohack.utils.fits import put_axis_in_fits_header, put_stokes_axis_in_fits_header, write_fits, put_resolution_in_fits_header def export_to_fits_panel_chunk(parm_dict): @@ -85,11 +85,11 @@ def export_to_fits_holog_chunk(parm_dict): if ntime != 1: raise Exception("Data with multiple times not supported for FITS export") - base_header = axis_to_fits_header(base_header, input_xds.chan.values, 3, 'Frequency', 'Hz') - base_header = stokes_axis_to_fits_header(base_header, 4) + base_header = put_axis_in_fits_header(base_header, input_xds.chan.values, 3, 'Frequency', 'Hz') + base_header = put_stokes_axis_in_fits_header(base_header, 4) rad_to_deg = convert_unit('rad', 'deg', 'trigonometric') - beam_header = axis_to_fits_header(base_header, -input_xds.l.values * rad_to_deg, 1, 'RA---SIN', 'deg') - beam_header = axis_to_fits_header(beam_header, input_xds.m.values * rad_to_deg, 2, 'DEC--SIN', 'deg') + beam_header = put_axis_in_fits_header(base_header, -input_xds.l.values * rad_to_deg, 1, 'RA---SIN', 'deg') + beam_header = put_axis_in_fits_header(beam_header, input_xds.m.values * rad_to_deg, 2, 'DEC--SIN', 'deg') beam_header['RADESYSA'] = 'FK5' beam = input_xds['BEAM'].values if parm_dict['complex_split'] == 'cartesian': @@ -103,9 +103,9 @@ def export_to_fits_holog_chunk(parm_dict): write_fits(beam_header, 'Complex beam phase', np.angle(beam), add_prefix(basename, 'beam_phase') + '.fits', 'Radians', 'image') wavelength = clight / input_xds.chan.values[0] - aperture_header = axis_to_fits_header(base_header, input_xds.u.values * wavelength, 1, 'X----LIN', 'm') - aperture_header = axis_to_fits_header(aperture_header, input_xds.u.values * wavelength, 2, 'Y----LIN', 'm') - aperture_header = resolution_to_fits_header(aperture_header, aperture_resolution) + aperture_header = put_axis_in_fits_header(base_header, input_xds.u.values * wavelength, 1, 'X----LIN', 'm') + aperture_header = put_axis_in_fits_header(aperture_header, input_xds.u.values * wavelength, 2, 'Y----LIN', 'm') + aperture_header = put_resolution_in_fits_header(aperture_header, aperture_resolution) aperture = input_xds['APERTURE'].values if parm_dict['complex_split'] == 'cartesian': write_fits(aperture_header, 'Complex aperture real part', aperture.real, @@ -118,9 +118,9 @@ def export_to_fits_holog_chunk(parm_dict): write_fits(aperture_header, 'Complex aperture phase', np.angle(aperture), add_prefix(basename, 'aperture_phase') + '.fits', 'rad', 'image') - phase_amp_header = axis_to_fits_header(base_header, input_xds.u_prime.values * wavelength, 1, 'X----LIN', 'm') - phase_amp_header = axis_to_fits_header(phase_amp_header, input_xds.v_prime.values * wavelength, 2, 'Y----LIN', 'm') - phase_amp_header = resolution_to_fits_header(phase_amp_header, aperture_resolution) + phase_amp_header = put_axis_in_fits_header(base_header, input_xds.u_prime.values * wavelength, 1, 'X----LIN', 'm') + phase_amp_header = put_axis_in_fits_header(phase_amp_header, input_xds.v_prime.values * wavelength, 2, 'Y----LIN', 'm') + phase_amp_header = put_resolution_in_fits_header(phase_amp_header, aperture_resolution) write_fits(phase_amp_header, 'Cropped aperture corrected phase', input_xds['CORRECTED_PHASE'].values, add_prefix(basename, 'corrected_phase') + '.fits', 'rad', 'image') return diff --git a/src/astrohack/visualization/plot_tools.py b/src/astrohack/visualization/plot_tools.py index 618132ea..cf361a93 100644 --- a/src/astrohack/visualization/plot_tools.py +++ b/src/astrohack/visualization/plot_tools.py @@ -1,5 +1,6 @@ import matplotlib.image import numpy as np +from scipy.stats import linregress from matplotlib import pyplot as plt from matplotlib.patches import Rectangle @@ -218,7 +219,10 @@ def scatter_plot( residuals_marker='+', residuals_color='black', residuals_linestyle='', - residuals_label='residuals' + residuals_label='residuals', + add_regression=False, + regression_linestyle='-', + regression_color='black' ): """ Do scatter simple scatter plots of data to a plotting axis @@ -245,12 +249,19 @@ def scatter_plot( model_color: Color of the model marker model_linestyle: Line style for connecting model points model_label: Label for model points + plot_residuals: Add a residuals subplot at the bottom when a model is provided + residuals_marker: Marker for residuals + residuals_color: Color for residual markers + residuals_linestyle: Line style for residuals + residuals_label: Label for residuals + add_regression: Add a linear regression between X and y data + regression_linestyle: Line style for the regression plot + regression_color: Color for the regression plot """ ax.plot(xdata, ydata, ls=data_linestyle, marker=data_marker, color=data_color, label=data_label) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) - if title is not None: - ax.set_title(title) + if xlim is not None: ax.set_xlim(xlim) if ylim is not None: @@ -274,6 +285,13 @@ def scatter_plot( rotation=20 ) + if add_regression: + slope, intercept, _, _, _ = linregress(xdata, ydata) + regression_label = f'y = {slope:.4f}*x + {intercept:.4f}' + yregress = slope*xdata + intercept + ax.plot(xdata, yregress, ls=regression_linestyle, color=regression_color, label=regression_label, lw=2) + ax.legend() + if model is not None: ax.plot(xdata, model, ls=model_linestyle, marker=model_marker, color=model_color, label=model_label) ax.legend() @@ -297,4 +315,7 @@ def scatter_plot( ax_res.axhline(0, color=hv_color, ls=hv_linestyle) ax_res.set_ylabel('Residuals') + if title is not None: + ax.set_title(title) + return diff --git a/tests/stakeholder/test_stakeholder_vla.py b/tests/stakeholder/test_stakeholder_vla.py index ae689d22..82346858 100644 --- a/tests/stakeholder/test_stakeholder_vla.py +++ b/tests/stakeholder/test_stakeholder_vla.py @@ -242,5 +242,9 @@ def test_holography_pipeline(set_data): exclude_shadows=False ) - assert verify_panel_shifts(data_dir=str(set_data), ref_mean_shift=reference_dict["vla"]["offsets"]), "Verify panel shifts" - #assert verify_panel_shifts(data_dir=str(set_data)), "Verify panel shifts" + reference_shifts = np.array([-91.6455227, 61.69666059, 4.39843319, 122.26547831]) + assert verify_panel_shifts(data_dir=str(set_data), ref_mean_shift=reference_shifts), "Verify panel shifts" + # This test using reference values is very hard to be updated, using this hardcoded reference_shifts is a + # temporary work around + # assert verify_panel_shifts(data_dir=str(set_data), ref_mean_shift=reference_dict["vla"]["offsets"]), \ + # "Verify panel shifts" diff --git a/tests/unit/test_class_antenna_surface.py b/tests/unit/test_class_antenna_surface.py index 2465ef16..a59ddcb6 100644 --- a/tests/unit/test_class_antenna_surface.py +++ b/tests/unit/test_class_antenna_surface.py @@ -90,8 +90,8 @@ def test_compile_panel_points_ringed(self): """ Tests that a point falls into the correct panel and that this panel has the correct number of samples """ - compvaluep0 = [3.2456341911764706, 0.7755055147058822, 198, 268, 0.00045656206805518506] - compnsampp0 = 120 + compvaluep0 = [3.3030790441176467, 0.43083639705882354, 197, 262, 0.00025600619999005243] + compnsampp0 = 179 self.tant.compile_panel_points() assert len(self.tant.panels[0].samples) == compnsampp0, 'Number of samples in panel is different from reference' assert self.tant.panels[0].samples[0] == PanelPoint(*compvaluep0), ('Point data in Panel is different from what' @@ -101,9 +101,10 @@ def test_fit_surface(self): """ Tests that fitting results for two panels match the reference """ - solveparsp0 = [0.00035746, 0.00020089, -0.0008455 ] - solveparsp30 = [ 0.00039911, -0.00041468, -0.0007079] + solveparsp0 = [0.00032385, 0.00037302, -0.00092492] + solveparsp30 = [0.00038098, -0.00039892, -0.00067244] self.tant.fit_surface() + assert len(self.tant.panels[0].model.parameters) == len(solveparsp0), ('Fitted results have a different length' ' from reference') for i in range(len(solveparsp30)): diff --git a/tests/unit/test_panel.py b/tests/unit/test_panel.py index 0f63952b..b4ea8f9f 100644 --- a/tests/unit/test_panel.py +++ b/tests/unit/test_panel.py @@ -202,7 +202,7 @@ def test_panel_absolute_clip(self): telescope = Telescope('vla') radius = panel_mds["ant_ea25"]["ddi_0"]['RADIUS'].values - dish_mask = np.where(radius < telescope.diam/2, 1.0, 0) + dish_mask = np.where(radius < telescope.oulim, 1.0, 0) dish_mask = np.where(radius < telescope.inlim, 0, dish_mask) nvalid_pix = np.sum(dish_mask)