diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml new file mode 100644 index 00000000..927723c7 --- /dev/null +++ b/config/inference_conf.yaml @@ -0,0 +1,37 @@ + +inference: + # Inference batch size + batch_size: 16 + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # Paths to the configuration files and trained model directory + configs: + # Path to the directory containing the trained model + trained_model_path: /path/to/trained/model/ + + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model + + # Relative path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml + + # Path to the data config file (this could contain prior information) + data_config_path: + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Number of wavelength bins used to reconstruct polychromatic objects. + n_bins_lda: 8 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. + output_Q: 1 + + # Dimension of the pixel PSF postage stamp + output_dim: 64 + + # Flag to perform centroid error correction + correct_centroids: False + + # Flag to perform CCD misalignment error correction + add_ccd_misalignments: True diff --git a/config/metrics_config.yaml b/config/metrics_config.yaml index cfaca9b9..e51b5ead 100644 --- a/config/metrics_config.yaml +++ b/config/metrics_config.yaml @@ -12,6 +12,9 @@ metrics: # Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory trained_model_config: + # Evaluate the chi2 metric. + eval_chi2_metric: True + + # Evaluate the monochromatic RMSE metric. eval_mono_metric: True diff --git a/data/generation/README.md b/data/generation/README.md index aa90bbbf..2a363cbb 100644 --- a/data/generation/README.md +++ b/data/generation/README.md @@ -8,6 +8,16 @@ To run the script: python data_generation_script.py -c data_generation_params_v.0.1.0.yml ``` +> ⚠️ Warning +> +> There are some differences with respect to the original dataset used for the WaveDiff paper (Liaudat et al. 2023), even though we use it as a reference: +> - The assignment of SEDs to each star will not match that of the original dataset, although the same templates are used. +> - The assigned noise level (SNR) of each star will not match the original dataset, although the same distribution is used. +> +> Nevertheless, both the `C_poly` coefficients and the star `positions` will match. Therefore, results obtained on the new datasets will not be exactly the same as those on the original datasets.
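If you need to check which parts of a regenerated dataset are reproducible, a quick comparison against the reference dataset can be run along the following lines. This is a minimal sketch: the file names and the `positions`/`C_poly` dictionary keys are assumptions based on the description above and may differ in your setup.

```python
import numpy as np

# Hypothetical file names; adapt to your local copies of the datasets.
new_ds = np.load("new_train_dataset.npy", allow_pickle=True)[()]
ref_ds = np.load("reference_train_dataset.npy", allow_pickle=True)[()]

# Quantities expected to match between the regenerated and original datasets
print(np.allclose(new_ds["positions"], ref_ds["positions"]))
print(np.allclose(new_ds["C_poly"], ref_ds["C_poly"]))

# SED assignment and noise realisations are NOT expected to match,
# so the star stamps and SEDs will differ between the two datasets.
```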
+ + +### Dataset description **Dataset 0.x.x:** @@ -56,3 +66,5 @@ python data_generation_script.py -c data_generation_params_v.0.1.0.yml - v3.2.1/2/3/4 with dummy (unitary) masks - v3.3.1/2/3/4 with realistic masks + + diff --git a/data/generation/data_generation_params_v.0.1.0.yml b/data/generation/data_generation_params_v.0.1.0.yml index e373cd06..221a81c2 100644 --- a/data/generation/data_generation_params_v.0.1.0.yml +++ b/data/generation/data_generation_params_v.0.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.0.2.0.yml b/data/generation/data_generation_params_v.0.2.0.yml index 804d9178..2708a649 100644 --- a/data/generation/data_generation_params_v.0.2.0.yml +++ b/data/generation/data_generation_params_v.0.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.0.3.0.yml b/data/generation/data_generation_params_v.0.3.0.yml index 5c351dc9..d7487e2b 100644 --- a/data/generation/data_generation_params_v.0.3.0.yml +++ b/data/generation/data_generation_params_v.0.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.1.1.0.yml b/data/generation/data_generation_params_v.1.1.0.yml index 14baccbf..f19d7719 100644 --- a/data/generation/data_generation_params_v.1.1.0.yml +++ b/data/generation/data_generation_params_v.1.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.1.2.0.yml b/data/generation/data_generation_params_v.1.2.0.yml index 6e489eaa..e0a9953c 100644 --- a/data/generation/data_generation_params_v.1.2.0.yml +++ b/data/generation/data_generation_params_v.1.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.1.3.0.yml b/data/generation/data_generation_params_v.1.3.0.yml index 8b48221d..542a1140 100644 --- a/data/generation/data_generation_params_v.1.3.0.yml +++ b/data/generation/data_generation_params_v.1.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.1.0.yml b/data/generation/data_generation_params_v.2.1.0.yml index a06903ca..f09542b1 100644 --- 
a/data/generation/data_generation_params_v.2.1.0.yml +++ b/data/generation/data_generation_params_v.2.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.2.0.yml b/data/generation/data_generation_params_v.2.2.0.yml index 7e2354b6..6c37148e 100644 --- a/data/generation/data_generation_params_v.2.2.0.yml +++ b/data/generation/data_generation_params_v.2.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.3.0.yml b/data/generation/data_generation_params_v.2.3.0.yml index 27f644ea..63c76356 100644 --- a/data/generation/data_generation_params_v.2.3.0.yml +++ b/data/generation/data_generation_params_v.2.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.1.0.yml b/data/generation/data_generation_params_v.3.1.0.yml index fd3b1c35..e1fc2f7b 100644 --- a/data/generation/data_generation_params_v.3.1.0.yml +++ b/data/generation/data_generation_params_v.3.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.2.0.yml b/data/generation/data_generation_params_v.3.2.0.yml index 224ecdeb..2c3c35dd 100644 --- a/data/generation/data_generation_params_v.3.2.0.yml +++ b/data/generation/data_generation_params_v.3.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.3.0.yml b/data/generation/data_generation_params_v.3.3.0.yml index bc457818..9e13d377 100644 --- a/data/generation/data_generation_params_v.3.3.0.yml +++ b/data/generation/data_generation_params_v.3.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_script.py b/data/generation/data_generation_script.py index f9746ad9..86ebdd1b 100644 --- a/data/generation/data_generation_script.py +++ b/data/generation/data_generation_script.py @@ -25,9 +25,9 @@ ) from wf_psf.utils.read_config import read_conf, RecursiveNamespace from wf_psf.sims.psf_simulator import PSFSimulator -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff from wf_psf.sims.spatial_varying_psf import 
SpatialVaryingPSF, ZernikeHelper -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator # Pre-defined colormap @@ -561,6 +561,7 @@ def main(args): # ------------ # # Centroid shifts + no_shift_test_zks = None if add_intrapixel_shifts: @@ -599,6 +600,7 @@ def main(args): # Add the centroid shifts to the Zernike coefficients train_zks += train_delta_centroid_shifts + no_shift_test_zks = np.copy(test_zks) test_zks += test_delta_centroid_shifts # ------------ # @@ -752,12 +754,24 @@ def main(args): sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins) ) + # Generate test polychromatic PSFs without shifts + if no_shift_test_zks is not None: + test_poly_psf_noshift_list = [] + print("Generate test PSFs at observation resolution without shifts") + for it in tqdm(range(no_shift_test_zks.shape[0])): + sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it, :]) + test_poly_psf_noshift_list.append( + sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins) + ) + # Generate numpy arrays from the lists train_poly_psf_np = np.array(train_poly_psf_list) train_SED_np = np.array(train_SED_list) test_poly_psf_np = np.array(test_poly_psf_list) test_SED_np = np.array(test_SED_list) + if no_shift_test_zks is not None: + test_poly_psf_noshift_np = np.array(test_poly_psf_noshift_list) # Generate the noisy train stars # Copy the training stars @@ -776,11 +790,28 @@ def main(args): axis=0, ) + # Also add noise to the test stars + noisy_test_poly_psf_np = np.copy(test_poly_psf_np) + # Generate a dataset with a SNR varying randomly within the desired range + rand_SNR = ( + np.random.rand(noisy_test_poly_psf_np.shape[0]) * (SNR_range[1] - SNR_range[0]) + ) + SNR_range[0] + # Add Gaussian noise to the observations + noisy_test_poly_psf_np = np.stack( + [ + add_noise(_im, desired_SNR=_SNR) + for _im, _SNR in zip(noisy_test_poly_psf_np, rand_SNR) + ], + axis=0, + ) + # ------------ # # Generate masks if add_masks: + masked_noisy_test_poly_psf_np = np.copy(noisy_test_poly_psf_np) + if mask_type == "random": # Generate random train masks train_masks = generate_n_mask( @@ -797,6 +828,12 @@ def main(args): noisy_train_poly_psf_np.dtype ) + # Apply the random masks to the test stars + masked_noisy_test_poly_psf_np = ( + masked_noisy_test_poly_psf_np + * test_masks.astype(noisy_test_poly_psf_np.dtype) + ) + # Turn masks to SHE convention. 
1 (True) means to mask and 0 (False) means to keep train_masks = ~train_masks test_masks = ~test_masks @@ -1093,8 +1130,21 @@ def main(args): SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins) ) + # Generate the test super resolved (SR) polychromatic PSFs without shifts + if no_shift_test_zks is not None: + SR_test_poly_psf_noshift_list = [] + + print("Generate testing SR PSFs no shifts") + for it_j in tqdm(range(n_test_stars)): + SR_sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it_j, :]) + SR_test_poly_psf_noshift_list.append( + SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins) + ) + # Generate numpy arrays from the lists SR_test_poly_psf_np = np.array(SR_test_poly_psf_list) + if no_shift_test_zks is not None: + SR_test_poly_psf_noshift_np = np.array(SR_test_poly_psf_noshift_list) # ------------ # # Save test datasets @@ -1131,6 +1181,7 @@ def main(args): test_psf_dataset = { "stars": test_poly_psf_np, "SR_stars": SR_test_poly_psf_np, + "noisy_stars": noisy_test_poly_psf_np, "positions": test_positions, "SEDs": test_SED_np, "zernike_GT": test_zks, @@ -1138,10 +1189,15 @@ def main(args): if add_masks: test_psf_dataset["masks"] = test_masks + test_psf_dataset["masked_noisy_stars"] = masked_noisy_test_poly_psf_np if add_ccd_misalignments: test_psf_dataset["zernike_ccd_misalignments"] = test_delta_Z3_arr + if no_shift_test_zks is not None: + test_psf_dataset["stars_noshift"] = test_poly_psf_noshift_np + test_psf_dataset["SR_stars_noshift"] = SR_test_poly_psf_noshift_np + if add_intrapixel_shifts: test_psf_dataset["zernike_centroid_shifts"] = test_delta_centroid_shifts test_psf_dataset["pix_centroid_shifts"] = np.stack( diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py index 5df41b29..988b02fe 100644 --- a/src/wf_psf/__init__.py +++ b/src/wf_psf/__init__.py @@ -1,7 +1,7 @@ import importlib # Dynamically import modules to trigger side effects when wf_psf is imported -importlib.import_module('wf_psf.psf_models.psf_models') -importlib.import_module('wf_psf.psf_models.psf_model_semiparametric') -importlib.import_module('wf_psf.psf_models.psf_model_physical_polychromatic') -importlib.import_module('wf_psf.psf_models.tf_psf_field') +importlib.import_module("wf_psf.psf_models.psf_models") +importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric") +importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic") +importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field") diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 72% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 75d3c7e9..01135428 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,22 +8,92 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from fractions import Fraction from typing import Optional +def compute_centroid_correction( + model_params, centroid_dataset, batch_size: int = 1 +) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. + + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. 
+ + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + + batch_size : int, optional + The batch size to use when processing the stars. Default is 1. + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None + + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, pixel_sampling: float = 12e-6, - reference_shifts: list[float] = [-1/3, -1/3], + reference_shifts: list[float] = [-1 / 3, -1 / 3], sigma_init: float = 2.5, n_iter: int = 20, ) -> np.ndarray: """ Compute Zernike tip-tilt corrections for a batch of PSF images. - This function estimates the centroid shifts of multiple PSFs and computes + This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference. Parameters ---------- @@ -39,7 +109,7 @@ def compute_zernike_tip_tilt( pixel_sampling : float, optional The pixel size in meters. Defaults to `12e-6 m` (12 microns). reference_shifts : list[float], optional - The target centroid shifts in pixels, specified as `[dy, dx]`. + The target centroid shifts in pixels, specified as `[dy, dx]`. Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). sigma_init : float, optional Initial standard deviation for centroid estimation. Default is `2.5`. n_iter : int, optional Number of iterations for centroid refinement. Default is `20`. Returns ------- @@ -52,19 +122,18 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - + Notes ----- - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention.
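As a rough usage sketch of the batched tip/tilt estimation above: the synthetic Gaussian stamps below are illustrative assumptions (in the pipeline the inputs would be observed star postage stamps), and the import path follows this branch's `wf_psf.data.centroids` layout.

```python
import numpy as np
from wf_psf.data.centroids import compute_zernike_tip_tilt

rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:32, 0:32]

# Batch of 4 Gaussian blobs with small random centroid offsets
centers = 15.5 + rng.uniform(-0.4, 0.4, size=(4, 2))
stamps = np.stack(
    [np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * 2.5**2)) for cy, cx in centers]
)

# One (Zk1, Zk2) pair per stamp, in the WaveDiff convention
zk1_2 = compute_zernike_tip_tilt(stamps, pixel_sampling=12e-6)
print(zk1_2.shape)  # (4, 2)
```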
""" + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( - im=star_images, - mask=star_masks, - sigma_init=sigma_init, - n_iter=n_iter - ) + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) shifts = centroid_estimator.get_intra_pixel_shifts() @@ -72,78 +141,79 @@ def compute_zernike_tip_tilt( reference_shifts = np.array(reference_shifts) # Reshape to ensure it's a column vector (1, 2) - reference_shifts = reference_shifts[None,:] - + reference_shifts = reference_shifts[None, :] + # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) - + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + # Compute displacements - displacements = (reference_shifts - shifts) # - + displacements = reference_shifts - shifts # + # Ensure the correct axis order for displacements (x-axis, then y-axis) - displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements - zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call - + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + # Reshape the result back to the original shape of displacements zk1_2_array = zk1_2_array.reshape(displacements.shape) - - return zk1_2_array + return zk1_2_array class CentroidEstimator: """ Calculate centroids and estimate intra-pixel shifts for a batch of star images. - This class estimates the centroid of each star in a batch of images using an - iterative process that fits an elliptical Gaussian model to the star images. - The estimated centroids are returned along with the intra-pixel shifts, which - represent the difference between the estimated centroid and the center of the + This class estimates the centroid of each star in a batch of images using an + iterative process that fits an elliptical Gaussian model to the star images. + The estimated centroids are returned along with the intra-pixel shifts, which + represent the difference between the estimated centroid and the center of the image grid (or pixel grid). - The process is vectorized, allowing multiple star images to be processed in + The process is vectorized, allowing multiple star images to be processed in parallel, which significantly improves performance when working with large batches. Parameters ---------- im : numpy.ndarray - A 3D numpy array of star image stamps. The shape of the array should be - (n_images, height, width), where n_images is the number of stars, and + A 3D numpy array of star image stamps. The shape of the array should be + (n_images, height, width), where n_images is the number of stars, and height and width are the dimensions of each star's image. - + mask : numpy.ndarray, optional - A 3D numpy array of the same shape as `im`, representing the mask for each star image. - A mask value of `0` indicates that the pixel is fully considered (unmasked), while a value of `1` means the pixel is completely ignored (masked). - Values between `0` and `1` act as weights, allowing partial consideration of the pixel. - If not provided, no mask is applied. + A 3D numpy array of the same shape as `im`, representing the mask for each star image. 
+ A mask value of `0` indicates that the pixel is fully considered (unmasked), while a value of `1` means the pixel is completely ignored (masked). + Values between `0` and `1` act as weights, allowing partial consideration of the pixel. + If not provided, no mask is applied. sigma_init : float, optional - The initial guess for the standard deviation (sigma) of the elliptical Gaussian + The initial guess for the standard deviation (sigma) of the elliptical Gaussian that models the star. Default is 7.5. n_iter : int, optional - The number of iterations for the iterative centroid estimation procedure. + The number of iterations for the iterative centroid estimation procedure. Default is 5. auto_run : bool, optional - If True, the centroid estimation procedure will be automatically run upon + If True, the centroid estimation procedure will be automatically run upon initialization. Default is True. xc : float, optional - The initial guess for the x-component of the centroid. If None, it is set + The initial guess for the x-component of the centroid. If None, it is set to the center of the image. Default is None. yc : float, optional - The initial guess for the y-component of the centroid. If None, it is set + The initial guess for the y-component of the centroid. If None, it is set to the center of the image. Default is None. Attributes ---------- xc : numpy.ndarray The x-components of the estimated centroids for each image. Shape is (n_images,). - + yc : numpy.ndarray The y-components of the estimated centroids for each image. Shape is (n_images,). @@ -154,10 +224,10 @@ class CentroidEstimator: elliptical_gaussian(e1=0, e2=0) Computes an elliptical 2D Gaussian with the specified shear parameters. - + compute_moments() Computes the first-order moments of the star images and updates the centroid estimates. - + estimate() Runs the iterative centroid estimation procedure for all images. @@ -171,13 +241,27 @@ class CentroidEstimator: ----- The iterative centroid estimation procedure fits an elliptical Gaussian to each star image and computes the centroid by calculating the weighted moments. The - `estimate()` method performs the centroid calculation for a batch of images using - the iterative approach defined by the `n_iter` parameter. This class is designed + `estimate()` method performs the centroid calculation for a batch of images using + the iterative approach defined by the `n_iter` parameter. This class is designed to be efficient and scalable when processing large batches of star images. 
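A minimal sketch of the batched centroid estimation described above; the synthetic stamps are an assumption, and real inputs would be star postage stamps, optionally with masks.

```python
import numpy as np
from wf_psf.data.centroids import CentroidEstimator

yy, xx = np.mgrid[0:32, 0:32]
# Three identical Gaussian stamps slightly offset from the stamp centre
stamps = np.stack(
    [np.exp(-((xx - 15.2) ** 2 + (yy - 15.8) ** 2) / (2 * 3.0**2)) for _ in range(3)]
)

est = CentroidEstimator(im=stamps, sigma_init=2.5, n_iter=20)  # auto_run=True by default
print(est.xc, est.yc)                # estimated centroid components per stamp
print(est.get_intra_pixel_shifts())  # shifts w.r.t. the stamp centre, shape (3, 2)
```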
""" - def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None): + def __init__( + self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None + ): """Initialize class attributes.""" + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: @@ -195,7 +279,6 @@ def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=No if auto_run: self.estimate() - def update_grid(self): """Vectorized update of the grid coordinates for multiple star stamps.""" @@ -205,9 +288,9 @@ def update_grid(self): y_range = np.arange(Ny) # Correct subtraction without mixing axes - self.xx = (x_range - self.xc[:, None]) - self.yy = (y_range - self.yc[:, None]) - + self.xx = x_range - self.xc[:, None] + self.yy = y_range - self.yc[:, None] + # Now, expand to the correct shape (num_images, Nx, Ny) # Add the extra dimension for the number of stars self.xx = self.xx[:, :, None] # Shape: (num_images, Nx, 1) @@ -221,7 +304,7 @@ def elliptical_gaussian(self, e1=0, e2=0): # Shear the grid coordinates gxx = (1 - e1) * self.xx - e2 * self.yy gyy = (1 + e1) * self.yy - e2 * self.xx - + # Compute elliptical Gaussian return np.exp(-(gxx**2 + gyy**2) / (2 * self.sigma_init**2)) @@ -235,7 +318,11 @@ def compute_moments(self): Q0 = np.sum(masked_im_window, axis=(1, 2)) # Sum over images and their pixels Q1 = np.array( [ - np.sum(np.sum(masked_im_window, axis=2 - i) * np.arange(self.stamp_size[i]), axis=1) + np.sum( + np.sum(masked_im_window, axis=2 - i) + * np.arange(self.stamp_size[i]), + axis=1, + ) for i in range(2) ] ) @@ -257,7 +344,7 @@ def get_centroids(self): def get_intra_pixel_shifts(self): """Get intra-pixel shifts for all images. - + Intra-pixel shifts are the differences between the estimated centroid and the center of the image stamp (or pixel grid). These shifts are calculated for all images in the batch. Returns @@ -265,8 +352,8 @@ def get_intra_pixel_shifts(self): np.array A 2D array of shape (num_of_images, 2), where each row corresponds to the x and y shifts for each image. """ - shifts = np.stack([self.xc - self.xc0, self.yc - self.yc0], axis=-1) - + shifts = np.stack([self.xc - self.xc0, self.yc - self.yc0], axis=-1) + return shifts diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py new file mode 100644 index 00000000..c18b9bf7 --- /dev/null +++ b/src/wf_psf/data/data_handler.py @@ -0,0 +1,447 @@ +"""Data Handler Module. + +Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. + +Includes: +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations + +This module serves as a central interface between raw data and modeling components. 
+ +Authors: Jennifer Pollack , Tobias Liaudat +""" + +import os +import numpy as np +import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor +import tensorflow as tf +from typing import Optional, Union +import logging + +logger = logging.getLogger(__name__) + + +class DataHandler: + """ + DataHandler for WaveDiff PSF modeling. + + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, and inference in the WaveDiff framework. + + Parameters + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). + simPSF : PSFSimulator + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. + n_bins_lambda : int + Number of wavelength bins used to discretize SEDs. + load_data : bool, optional + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. + + Attributes + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. + """ + + def __init__( + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, + ): + """ + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. + + Parameters + ---------- + dataset_type : str + One of {"train", "test", "inference"} indicating dataset usage. + data_params : RecursiveNamespace + Configuration object with paths, preprocessing options, and metadata. + simPSF : PSFSimulator + Used to convert SEDs to TensorFlow format. + n_bins_lambda : int + Number of wavelength bins for the SEDs. + load_data : bool, optional + Whether to automatically load and process the dataset (default: True). 
+ dataset : dict or list, optional + A pre-loaded dataset to use directly (overrides `load_data`). + sed_data : array-like, optional + Pre-loaded SED data to use directly. If not provided but `dataset` is, + SEDs are taken from `dataset["SEDs"]`. + + Raises + ------ + ValueError + If SEDs cannot be found in either `dataset` or as `sed_data`. + + Notes + ----- + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + `load_data=True` is used. + - TensorFlow conversion is performed at the end of initialization via `validate_and_process_dataset()`. + """ + + self.dataset_type = dataset_type + self.data_params = data_params + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: + self.load_dataset() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() + else: + self.dataset = None + self.sed_data = None + + @property + def tf_positions(self): + return ensure_tensor(self.dataset["positions"]) + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified dataset type. + + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), + allow_pickle=True, + )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "training": + if "noisy_stars" not in self.dataset: + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "test": + if "stars" not in self.dataset: + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "inference": + pass + else: + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + + # Use the same dataset_type labels as in _validate_dataset_structure + if self.dataset_type == "training": + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif self.dataset_type == "test": + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None.
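As an illustration of the "manual" initialization mode listed above, the sketch below defers all loading and attaches a small toy dictionary by hand. Passing `None` for `data_params` and `simPSF` is an assumption that only holds because neither `load_dataset()` nor `process_sed_data()` is triggered here.

```python
import numpy as np
from wf_psf.data.data_handler import DataHandler

handler = DataHandler(
    dataset_type="test",
    data_params=None,  # not used in manual mode
    simPSF=None,       # not used unless SEDs are processed
    n_bins_lambda=8,
    load_data=False,
)

# Attach a toy dataset and run validation + TensorFlow conversion explicitly
handler.dataset = {
    "positions": np.zeros((4, 2), dtype=np.float32),
    "stars": np.zeros((4, 32, 32), dtype=np.float32),
}
handler.validate_and_process_dataset()

print(handler.tf_positions.shape)  # (4, 2) as a tf.float32 tensor
```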
+ + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. + """ + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + + self.sed_data = [ + utils.generate_SED_elems_in_tensorflow( + _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 + ) + for _sed in sed_data + ] + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) + self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """ + Extract and concatenate star-related data from training and test datasets. + + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). + test_key : str + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + Concatenated NumPy array containing the selected data from both + training and test sets. + + Raises + ------ + KeyError + If either the training or test dataset does not contain the + requested key. + + Notes + ----- + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) + + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = True, +) -> Optional[np.ndarray]: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. 
+ run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default True + Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. + + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_direct_data(data, effective_key, allow_missing) + except Exception: + if allow_missing: + return None + raise + + +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: + """ + Extract data directly with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. 
+ KeyError + If key is not found in dataset and allow_missing=False. + + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..54582e7d --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,457 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. + +:Author: Tobias Liaudat + +""" + +from dataclasses import dataclass +from typing import Optional, Union +import numpy as np +import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment +from wf_psf.utils.read_config import RecursiveNamespace +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ZernikeInputs: + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + + +class ZernikeInputsFactory: + @staticmethod + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: + """Builds a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. + run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset, positions = None, None + + if run_type in {"training", "simulation", "metrics"}: + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Explicit prior provided; ignoring dataset-based prior." 
+ ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + stamps = get_data_array(data=data, run_type=run_type, key="sources") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + ) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + if len(contributions) == 1: + return contributions[0] + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") + + combined = np.zeros((n_samples, max_order)) + # Pad each contribution to the max order and sum them + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. + + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. 
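A small sketch of how contributions of different Zernike orders are combined by the helper above; the shapes (a 45-order prior plus a 3-column tip/tilt correction) are illustrative assumptions.

```python
import numpy as np
from wf_psf.data.data_zernike_utils import combine_zernike_contributions

n_stars = 10
prior = 1e-2 * np.random.randn(n_stars, 45)  # e.g. a per-star Zernike prior
tip_tilt = np.zeros((n_stars, 3))            # Z1/Z2 correction with a zero-padded first column
tip_tilt[:, 1:] = 1e-3 * np.random.randn(n_stars, 2)

combined = combine_zernike_contributions([prior, tip_tilt])
print(combined.shape)  # (10, 45): the lower-order term is zero-padded before summing
```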
+ """ + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """ + Assemble the total Zernike contribution map by combining the prior, + centroid correction, and CCD misalignment correction. + + Parameters + ---------- + model_params : RecursiveNamespace + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. + centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. + + Returns + ------- + tf.Tensor + A tensor representing the full Zernike contribution map. + """ + zernike_contribution_list = [] + + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, tf.Tensor): + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + + elif not isinstance(zernike_prior, np.ndarray): + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") + + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) + zernike_contribution_list.append(centroid_correction) + else: + logger.info("Skipping centroid correction (not enabled or no dataset).") + + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append(ccd_misalignment) + else: + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) + + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. 
Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) + + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) + + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + + +def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 1 (2) for a given shift in x (y) in WaveDiff conventions. + + All inputs should be in [m]. + A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, + e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. + + The output Zernike coefficient is in [um] units, as expected by WaveDiff. + + To match the centroid for a shift `dx` with corresponding `zk1`, + the new PSF should be generated with `-zk1`. + + The same applies to `dy` and `zk2`. + + Parameters + ---------- + dxy : float + Centroid shift in [m]. It can be on the x-axis or the y-axis. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + reference_pix_sampling = 12e-6 + zernike_norm_factor = 2.0 + + # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) + return ( + zernike_norm_factor + * (tel_diameter / 2) + * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) + * 3.0 + ) + + +def compute_zernike_tip_tilt( + star_images: np.ndarray, + star_masks: Optional[np.ndarray] = None, + pixel_sampling: float = 12e-6, + reference_shifts: list[float] = [-1 / 3, -1 / 3], + sigma_init: float = 2.5, + n_iter: int = 20, +) -> np.ndarray: + """ + Compute Zernike tip-tilt corrections for a batch of PSF images. + + This function estimates the centroid shifts of multiple PSFs and computes + the corresponding Zernike tip-tilt corrections to align them with a reference. + + Parameters + ---------- + star_images : np.ndarray + A batch of PSF images (3D array of shape `(num_images, height, width)`). + star_masks : np.ndarray, optional + A batch of masks (same shape as `star_images`). Each mask can have: + - `0` to ignore the pixel. + - `1` to fully consider the pixel. + - Values in `(0,1]` as weights for partial consideration. + Defaults to None. + pixel_sampling : float, optional + The pixel size in meters. Defaults to `12e-6 m` (12 microns). + reference_shifts : list[float], optional + The target centroid shifts in pixels, specified as `[dy, dx]`. + Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). + sigma_init : float, optional + Initial standard deviation for centroid estimation. Default is `2.5`. + n_iter : int, optional + Number of iterations for centroid refinement. Default is `20`. + + Returns + ------- + np.ndarray + An array of shape `(num_images, 2)`, where: + - Column 0 contains `Zk1` (tip) values. + - Column 1 contains `Zk2` (tilt) values. + + Notes + ----- + - This function processes all images at once using vectorized operations. + - The Zernike coefficients are computed in the WaveDiff convention.
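Following the docstring of `shift_x_y_to_zk1_2_wavediff` defined earlier in this module, a worked example for a 0.5 pixel shift sampled at 12 um (the numbers are illustrative only):

```python
from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff

dx_pixels = 0.5
pixel_scale = 12e-6  # [m]

zk1 = shift_x_y_to_zk1_2_wavediff(dx_pixels * pixel_scale)
# To move the PSF centroid by +dx, the PSF is generated with -zk1 (likewise for dy / zk2)
print(zk1)
```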
+ """ + from wf_psf.data.centroids import CentroidEstimator + + # Vectorize the centroid computation + centroid_estimator = CentroidEstimator( + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) + + shifts = centroid_estimator.get_intra_pixel_shifts() + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Reshape to ensure it's a column vector (1, 2) + reference_shifts = reference_shifts[None, :] + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + + # Compute displacements + displacements = reference_shifts - shifts # + + # Ensure the correct axis order for displacements (x-axis, then y-axis) + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + + # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + + # Reshape the result back to the original shape of displacements + zk1_2_array = zk1_2_array.reshape(displacements.shape) + + return zk1_2_array + + +def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in zemax conventions. + + All inputs should be in [m]. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + # Convert to waves with a reference of 800nm + zk4 /= 800e-9 + # Remove the peak to valley value + zk4 /= 2.0 + + return zk4 + + +def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in WaveDifff conventions. + + All inputs should be in [m]. + + The output zernike coefficient is in [um] units as expected by wavediff. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + + # Remove the peak to valley value + zk4 /= 2.0 + + # Change units to [um] as Wavediff uses + zk4 *= 1e6 + + return zk4 diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py deleted file mode 100644 index 95afebb1..00000000 --- a/src/wf_psf/data/training_preprocessing.py +++ /dev/null @@ -1,445 +0,0 @@ -"""Training Data Processing. - -A module to load and preprocess training and validation test data. - -:Authors: Jennifer Pollack and Tobias Liaudat - -""" - -import os -import numpy as np -import wf_psf.utils.utils as utils -import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -class DataHandler: - """Data Handler. 
- - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - - Parameters - ---------- - dataset_type: str - A string indicating type of data ("train" or "test"). - data_params: Recursive Namespace object - Recursive Namespace object containing training data parameters - simPSF: PSFSimulator - An instance of the PSFSimulator class for simulating a PSF. - n_bins_lambda: int - The number of bins in wavelength. - load_data: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. - - Attributes - ---------- - dataset_type: str - A string indicating the type of dataset ("train" or "test"). - data_params: Recursive Namespace object - A Recursive Namespace object containing training or test data parameters. - dataset: dict - A dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF: object - An instance of the SimPSFToolkit class for simulating PSF. - n_bins_lambda: int - The number of bins in wavelength. - sed_data: tf.Tensor - A TensorFlow tensor containing the SED data for training/testing. - load_data_on_init: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. - """ - - def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool=True): - """ - Initialize the dataset handler for PSF simulation. - - Parameters - ---------- - dataset_type : str - A string indicating the type of data ("train" or "test"). - data_params : Recursive Namespace object - A Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - An instance of the PSFSimulator class for simulating a PSF. - n_bins_lambda : int - The number of bins in wavelength. - load_data : bool, optional - A flag to control whether data should be loaded and processed during initialization. - If True, data is loaded and processed during initialization; if False, data loading - is deferred until explicitly called. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() - - - def load_dataset(self): - """Load dataset. - - Load the dataset based on the specified dataset type. - - """ - self.dataset = np.load( - os.path.join(self.data_params.data_dir, self.data_params.file), - allow_pickle=True, - )[()] - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - if self.dataset_type == "training": - if "noisy_stars" in self.dataset: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") - elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - - - def process_sed_data(self): - """Process SED Data. - - A method to generate and process SED data. 
- - """ - self.sed_data = [ - utils.generate_SED_elems_in_tensorflow( - _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 - ) - for _sed in self.dataset["SEDs"] - ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) - self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key for key, dataset in [(train_key, data.training_data.dataset), (test_key, data.test_data.dataset)] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. 
- - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i:i + batch_size] - batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. 
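Both corrections end up as zero-padded blocks in the per-star Zernike array: the `compute_zernike_tip_tilt` output is negated and left-padded by one column, while the defocus term below is left-padded by three columns. A minimal NumPy sketch of that packing, using random values purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n_stars = 5

# Stand-in for compute_zernike_tip_tilt output: shape (n_stars, 2) with (Zk1, Zk2) per star.
zk1_2 = -1.0 * rng.normal(scale=0.01, size=(n_stars, 2))

# Left-pad one zero column, as compute_centroid_correction does -> shape (n_stars, 3).
centroid_block = np.pad(zk1_2, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)

# Stand-in for per-star defocus (Zk4) values: shape (n_stars, 1).
zk4 = rng.normal(scale=0.01, size=(n_stars, 1))

# Left-pad three zero columns, as compute_ccd_misalignment does -> shape (n_stars, 4).
ccd_block = np.pad(zk4, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0)

print(centroid_block.shape, ccd_block.shape)  # (5, 3) (5, 4)
```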
- """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int=16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py new file mode 100644 index 00000000..e660154b --- /dev/null +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,407 @@ +"""Inference. + +A module which provides a PSFInference class to perform inference +with trained PSF models. 
It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. + +:Authors: Jennifer Pollack , Tobias Liaudat + +""" + +import os +from pathlib import Path +import numpy as np +from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf +from wf_psf.utils.utils import ensure_batch +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +import tensorflow as tf + + +class InferenceConfigHandler: + ids = ("inference_conf",) + + def __init__(self, inference_config_path: str): + self.inference_config_path = inference_config_path + self.inference_config = None + self.training_config = None + self.data_config = None + + def load_configs(self): + """Load configuration files based on the inference config.""" + self.inference_config = read_conf(self.inference_config_path) + self.set_config_paths() + self.training_config = read_conf(self.trained_model_config_path) + + if self.data_config_path is not None: + # Load the data configuration + self.data_config = read_conf(self.data_config_path) + + def set_config_paths(self): + """Extract and set the configuration paths.""" + # Set config paths + config_paths = self.inference_config.inference.configs + + self.trained_model_path = Path(config_paths.trained_model_path) + self.model_subdir = config_paths.model_subdir + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) + self.data_config_path = config_paths.data_config_path + + @staticmethod + def overwrite_model_params(training_config=None, inference_config=None): + """ + Overwrite training model_params with values from inference_config if available. + + Parameters + ---------- + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. + """ + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params + + if model_params and inf_model_params: + for key, value in inf_model_params.__dict__.items(): + if hasattr(model_params, key): + setattr(model_params, key, value) + + +class PSFInference: + """ + Perform PSF inference using a pre-trained WaveDiff model. + + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + x_field : array-like, optional + x coordinates in SHE convention. + y_field : array-like, optional + y coordinates in SHE convention. + seds : array-like, optional + Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. 
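A typical call sequence for this class might look like the sketch below. The config path is a placeholder, the SED array is a toy array shaped `(n_stars, n_sed_samples, 2)` as required further down in `_prepare_positions_and_seds`, and none of the values would produce a physically meaningful PSF.

```python
import numpy as np
from wf_psf.inference.psf_inference import PSFInference

# Two hypothetical field positions (SHE convention) and toy SEDs of shape (2, 20, 2).
x_field = np.array([120.0, 540.0])
y_field = np.array([300.0, 810.0])
seds = np.stack(
    [np.column_stack((np.full(20, 1.0 / 20), np.linspace(550.0, 900.0, 20))) for _ in range(2)]
)

psf_inf = PSFInference(
    inference_config_path="config/inference_conf.yaml",  # placeholder path
    x_field=x_field,
    y_field=y_field,
    seds=seds,
)

psfs = psf_inf.run_inference()  # array of shape (n_stars, output_dim, output_dim)
first_psf = psf_inf.get_psf(0)  # single postage stamp
```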
+ """ + + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): + + self.inference_config_path = inference_config_path + + # Inputs for the model + self.x_field = x_field + self.y_field = y_field + self.seds = seds + self.sources = sources + self.masks = masks + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """Prepare the configuration for inference.""" + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config + ) + + @property + def inference_config(self): + return self.config_handler.inference_config + + @property + def training_config(self): + return self.config_handler.training_config + + @property + def data_config(self): + return self.config_handler.data_config + + @property + def simPSF(self): + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.training_config.training.model_params) + return self._simPSF + + def _prepare_dataset_for_inference(self): + """Prepare dataset dictionary for inference, returning None if positions are invalid.""" + positions = self.get_positions() + if positions is None: + return None + return {"positions": positions, "sources": self.sources, "masks": self.masks} + + @property + def data_handler(self): + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=self._prepare_dataset_for_inference(), + sed_data=self.seds, + ) + self._data_handler.run_type = "inference" + return self._data_handler + + @property + def trained_psf_model(self): + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model + + def get_positions(self): + """ + Combine x_field and y_field into position pairs. + + Returns + ------- + numpy.ndarray + Array of shape (num_positions, 2) where each row contains [x, y] coordinates. + Returns None if either x_field or y_field is None. + + Raises + ------ + ValueError + If x_field and y_field have different lengths. + """ + if self.x_field is None or self.y_field is None: + return None + + x_arr = np.asarray(self.x_field) + y_arr = np.asarray(self.y_field) + + if x_arr.size == 0 or y_arr.size == 0: + return None + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. 
" + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Flatten arrays to handle any input shape, then stack + x_flat = x_arr.flatten() + y_flat = y_arr.flatten() + + return np.column_stack((x_flat, y_flat)) + + def load_inference_model(self): + """Load the trained PSF model based on the inference configuration.""" + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name + + weights_path_pattern = os.path.join( + model_path, + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", + ) + + # Load the trained PSF model + return load_trained_psf_model( + self.training_config, + self.data_handler, + weights_path_pattern, + ) + + @property + def n_bins_lambda(self): + if self._n_bins_lambda is None: + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) + return self._n_bins_lambda + + @property + def batch_size(self): + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """ + Preprocess and return tensors for positions and SEDs with consistent shapes. + + Handles single-star, multi-star, and even scalar inputs, ensuring: + - positions: shape (n_samples, 2) + - sed_data: shape (n_samples, n_bins, 2) + """ + # Ensure x_field and y_field are at least 1D + x_arr = np.atleast_1d(self.x_field) + y_arr = np.atleast_1d(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) + + if sed_data.shape[2] != 2: + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). 
Got {sed_data.shape}" + ) + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + + def run_inference(self): + """Run PSF inference and return the full PSF array.""" + # Prepare the configuration for inference + self.prepare_configs() + + # Prepare positions and SEDs for inference + positions, sed_data = self._prepare_positions_and_seds() + + self.engine = PSFInferenceEngine( + trained_model=self.trained_psf_model, + batch_size=self.batch_size, + output_dim=self.output_dim, + ) + return self.engine.compute_psfs(positions, sed_data) + + def _ensure_psf_inference_completed(self): + if self.engine is None or self.engine.inferred_psfs is None: + self.run_inference() + + def get_psfs(self): + self._ensure_psf_inference_completed() + return self.engine.get_psfs() + + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + If only a single star was passed during instantiation, the index defaults to 0. + """ + self._ensure_psf_inference_completed() + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + + +class PSFInferenceEngine: + def __init__(self, trained_model, batch_size: int, output_dim: int): + self.trained_model = trained_model + self.batch_size = batch_size + self.output_dim = output_dim + self._inferred_psfs = None + + @property + def inferred_psfs(self) -> np.ndarray: + """Access the cached inferred PSFs, if available.""" + return self._inferred_psfs + + def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: + """Compute and cache PSFs for the input source parameters.""" + n_samples = positions.shape[0] + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) + + # Initialize counter + counter = 0 + while counter < n_samples: + # Calculate the batch end element + end_sample = min(counter + self.batch_size, n_samples) + + # Define the batch positions + batch_pos = positions[counter:end_sample, :] + batch_seds = sed_data[counter:end_sample, :, :] + batch_inputs = [batch_pos, batch_seds] + # Generate PSFs for the current batch + batch_psfs = self.trained_model(batch_inputs, training=False) + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() + + # Update the counter + counter = end_sample + return self._inferred_psfs + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs.""" + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. Call compute_psfs() first.") + return self._inferred_psfs + + def get_psf(self, index: int) -> np.ndarray: + """Get the PSF at a specific index.""" + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. 
Call compute_psfs() first.") + return self._inferred_psfs[index] diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 86% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py index 0f51c32f..6b73babe 100644 --- a/src/wf_psf/utils/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,45 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff + + +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array class CCDMisalignmentCalculator: @@ -18,8 +56,8 @@ class CCDMisalignmentCalculator: This class processes and analyzes CCD misalignment data using tile position information. - The `tiles_data` array is a data cube where each slice is a 4×3 matrix representing - the four corners of a tile. The first two columns correspond to x/y coordinates (in mm), + The `tiles_data` array is a data cube where each slice is a 4×3 matrix representing + the four corners of a tile. The first two columns correspond to x/y coordinates (in mm), and the third column represents z displacement (in µm). Parameters @@ -62,6 +100,7 @@ class CCDMisalignmentCalculator: d_list : np.ndarray List of plane offset values for CCD planes. 
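For context, `get_zk4_from_position` further down converts the interpolated z displacement through `defocus_to_zk4_wavediff` (imported from `wf_psf.data.data_zernike_utils`). The standalone sketch below reproduces the same conversion steps with an illustrative 10 um CCD displacement:

```python
import numpy as np


def defocus_to_zk4(dz, tel_focal_length=24.5, tel_diameter=1.2):
    """Same steps as defocus_to_zk4_wavediff: dz in [m] -> Zk4 in [um]."""
    zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2)  # geometric defocus term
    zk4 /= np.sqrt(3)  # Zernike Z4 normalisation (depends on the chosen basis)
    zk4 /= 2.0         # remove the peak-to-valley factor
    zk4 *= 1e6         # convert to [um], the unit WaveDiff expects
    return zk4


dz = 10e-6  # illustrative z displacement of 10 um
print(f"dz = {dz:.1e} m -> Zk4 = {defocus_to_zk4(dz):.4e} um")
```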
""" + def __init__( self, tiles_path: str, @@ -82,7 +121,11 @@ def __init__( raise ValueError("Tile data must have three coordinate columns (x, y, z).") # Initialize attributes - self.tiles_x_lims, self.tiles_y_lims, self.tiles_z_lims = np.zeros(2), np.zeros(2), np.zeros(2) + self.tiles_x_lims, self.tiles_y_lims, self.tiles_z_lims = ( + np.zeros(2), + np.zeros(2), + np.zeros(2), + ) self.tiles_z_average: float = 0.0 self.ccd_polygons: list[mpltPath.Path] = [] @@ -94,7 +137,6 @@ def __init__( self._initialize() - def _initialize(self) -> None: """Run all required initialization steps.""" self._preprocess_tile_data() @@ -104,16 +146,20 @@ def _initialize(self) -> None: def _preprocess_tile_data(self) -> None: """Preprocess tile data by computing spatial limits and averages.""" - self.tiles_x_lims = np.array([np.min(self.tiles_data[:, 0, :]), np.max(self.tiles_data[:, 0, :])]) - self.tiles_y_lims = np.array([np.min(self.tiles_data[:, 1, :]), np.max(self.tiles_data[:, 1, :])]) - self.tiles_z_lims = np.array([np.min(self.tiles_data[:, 2, :]), np.max(self.tiles_data[:, 2, :])]) + self.tiles_x_lims = np.array( + [np.min(self.tiles_data[:, 0, :]), np.max(self.tiles_data[:, 0, :])] + ) + self.tiles_y_lims = np.array( + [np.min(self.tiles_data[:, 1, :]), np.max(self.tiles_data[:, 1, :])] + ) + self.tiles_z_lims = np.array( + [np.min(self.tiles_data[:, 2, :]), np.max(self.tiles_data[:, 2, :])] + ) self.tiles_z_average = np.mean(self.tiles_z_lims) - def _initialize_polygons(self): """Initialize polygons to look for CCD IDs""" - # Build polygon list corresponding to each CCD self.ccd_polygons = [] @@ -180,7 +226,6 @@ def _precompute_CCD_planes(self): self.normal_list.append(normal) self.d_list.append(d) - def scale_position_to_tile_reference(self, pos): """Scale input position into tiles coordinate system. @@ -190,7 +235,6 @@ def scale_position_to_tile_reference(self, pos): Focal plane position in wavediff coordinate system respecting `self.x_lims` and `self.y_lims`. Shape: (2,) """ - self.check_position_wavediff_limits(pos) pos_x = pos[0] @@ -210,7 +254,6 @@ def scale_position_to_tile_reference(self, pos): return np.array([scaled_x, scaled_y]) - def scale_position_to_wavediff_reference(self, pos): """Scale input position into wavediff coordinate system. @@ -219,7 +262,6 @@ def scale_position_to_wavediff_reference(self, pos): pos : np.ndarray Tile position in input tile coordinate system. Shape: (2,) """ - self.check_position_tile_limits(pos) pos_x = pos[0] @@ -239,7 +281,6 @@ def scale_position_to_wavediff_reference(self, pos): def check_position_wavediff_limits(self, pos): """Check if position is within wavediff limits.""" - if (pos[0] < self.x_lims[0] or pos[0] > self.x_lims[1]) or ( pos[1] < self.y_lims[0] or pos[1] > self.y_lims[1] ): @@ -249,14 +290,12 @@ def check_position_wavediff_limits(self, pos): def check_position_tile_limits(self, pos): """Check if position is within tile limits.""" - if (pos[0] < self.tiles_x_lims[0] or pos[0] > self.tiles_x_lims[1]) or ( pos[1] < self.tiles_y_lims[0] or pos[1] > self.tiles_y_lims[1] ): raise ValueError( "Input position is not within the tile focal plane limits." ) - def get_ccd_from_position(self, pos): """Get CCD ID from the position. @@ -299,7 +338,6 @@ def get_ccd_from_position(self, pos): return ccd_id - def get_dz_from_position(self, pos): """Get z-axis displacement for a focal plane position. 
@@ -328,7 +366,6 @@ def get_dz_from_position(self, pos): return dz - def get_zk4_from_position(self, pos): """Get defocus Zernike contribution from focal plane position. @@ -343,12 +380,12 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. """ + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff dz = self.get_dz_from_position(pos) return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) - @staticmethod def compute_z_from_plane_data(pos, normal, d): """Compute z value from plane data. @@ -373,12 +410,10 @@ def compute_z_from_plane_data(pos, normal, d): d : np.ndarray `d` value from the plane ecuation. Shape (3,) """ - z = (-normal[0] * pos[0] - normal[1] * pos[1] - d) * 1.0 / normal[2] return z - @staticmethod def check_position_format(pos): if type(pos) is list: diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 3bf671a4..76895cc7 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -14,6 +14,7 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.psf_models import build_PSF_model from wf_psf.sims import psf_simulator as psf_simulator +from wf_psf.training.train_utils import compute_noise_std_from_stars import logging logger = logging.getLogger(__name__) @@ -92,7 +93,9 @@ def compute_poly_metric( preds = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) # Ground truth data preparation - if dataset_dict is None or "stars" not in dataset_dict: + if dataset_dict is None or ( + "stars" not in dataset_dict and "noisy_stars" not in dataset_dict + ): logger.info( "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings." ) @@ -112,8 +115,16 @@ def compute_poly_metric( gt_preds = gt_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size) else: - logger.info("Using precomputed ground truth stars from dataset_dict['stars'].") - gt_preds = dataset_dict["stars"] + if "stars" in dataset_dict: + gt_preds = dataset_dict["stars"] + logger.info( + "Using precomputed ground truth stars from dataset_dict['stars']." + ) + elif "noisy_stars" in dataset_dict: + gt_preds = dataset_dict["noisy_stars"] + logger.info( + "Using precomputed noisy ground truth stars from dataset_dict['noisy_stars']." + ) # If the data is masked, mask the predictions if mask: @@ -157,6 +168,204 @@ def compute_poly_metric( return rmse, rel_rmse, std_rmse, std_rel_rmse +def compute_chi2_metric( + tf_trained_psf_model, + gt_tf_psf_model, + simPSF_np, + tf_pos, + tf_SEDs, + n_bins_lda=20, + n_bins_gt=20, + batch_size=16, + dataset_dict=None, + mask=False, +): + """Calculate the chi2 metric for polychromatic reconstructions at observation resolution. + + The ``tf_trained_psf_model`` should be the model to evaluate, and the + ``gt_tf_psf_model`` should be loaded with the ground truth PSF field. + + Parameters + ---------- + tf_trained_psf_model: PSF field object + Trained model to evaluate. + gt_tf_psf_model: PSF field object + Ground truth model to produce gt observations at any position + and wavelength. + simPSF_np: PSF simulator object + Simulation object to be used by ``generate_packed_elems`` function. + tf_pos: Tensor or numpy.ndarray [batch x 2] floats + Positions to evaluate the model. + tf_SEDs: numpy.ndarray [batch x SED_samples x 2] + SED samples for the corresponding positions. 
+ n_bins_lda: int + Number of wavelength bins to use for the polychromatic PSF. + n_bins_gt: int + Number of wavelength bins to use for the ground truth polychromatic PSF. + batch_size: int + Batch size for the PSF calcualtions. + dataset_dict: dict + Dictionary containing the dataset information. If provided, and if the `'stars'` key + is present, the noiseless stars from the dataset are used to compute the metrics. + Otherwise, the stars are generated from the gt model. + Default is `None`. + mask: bool + If `True`, predictions are masked using the same mask as the target data, ensuring + that metric calculations consider only unmasked regions. + Default is `False`. + + Returns + ------- + reduced_chi2_stat: float + Reduced chi squared value. + avg_noise_std_dev: float + Average estimated noise standard deviation used for the chi squared calculation. + + """ + # Create flag + noiseless_stars = False + + # Generate SED data list for the model + packed_SED_data = [ + utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_lda) + for _sed in tf_SEDs + ] + tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32) + tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1]) + pred_inputs = [tf_pos, tf_packed_SED_data] + + # Model prediction + preds = tf_trained_psf_model.predict(x=pred_inputs, batch_size=batch_size) + + # Ground truth data preparation + if dataset_dict is None or ( + "stars" not in dataset_dict and "noisy_stars" not in dataset_dict + ): + logger.info( + "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings." + ) + # The stars will be noiseless as we are recreating them from the ground truth model + noiseless_stars = True + + # Change interpolation parameters for the ground truth simPSF + simPSF_np.SED_interp_pts_per_bin = 0 + simPSF_np.SED_sigma = 0 + # Generate SED data list for gt model + packed_SED_data = [ + utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_gt) + for _sed in tf_SEDs + ] + tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32) + tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1]) + pred_inputs = [tf_pos, tf_packed_SED_data] + + # Ground Truth model prediction + reference_stars = gt_tf_psf_model.predict(x=pred_inputs, batch_size=batch_size) + + else: + if "stars" in dataset_dict: + reference_stars = dataset_dict["stars"] + logger.info( + "Using precomputed ground truth stars from dataset_dict['stars']." + ) + elif "noisy_stars" in dataset_dict: + reference_stars = dataset_dict["noisy_stars"] + logger.info( + "Using precomputed noisy ground truth stars from dataset_dict['noisy_stars']." + ) + + # If the data is masked, mask the predictions + if mask: + logger.info( + "Applying masks to predictions. Only unmasked regions will be considered for metric calculations." + ) + # Change convention + masks = 1 - dataset_dict["masks"] + # Ensure masks as float dtype + masks = masks.astype(preds.dtype) + + else: + # We create a dummy mask of ones + masks = np.ones_like(reference_stars, dtype=preds.dtype) + + # Compute noise standard deviation from the reference stars + if not noiseless_stars: + estimated_noise_std_dev = compute_noise_std_from_stars( + reference_stars, masks.astype(bool) + ) + # Check if there is a zero value + if np.any(estimated_noise_std_dev == 0): + logger.info( + "Chi2 metric calculation: Some estimated standard deviations are zero. Setting them to 1 to avoid division by zero." 
+ ) + estimated_noise_std_dev[estimated_noise_std_dev == 0] = 1.0 + else: + # If the stars are noiseless, we set the std dev to 1 + estimated_noise_std_dev = np.ones(reference_stars.shape[0], dtype=preds.dtype) + logger.info( + "Using noiseless stars for chi2 calculation. Setting all std dev to 1." + ) + + # Compute residuals + residuals = (reference_stars - preds) * masks + + # Standardize residuals -> remove mean and divide by std dev + standardized_residuals = np.array( + [ + (residual - np.sum(residual) / np.sum(mask)) / std_est + for residual, mask, std_est in zip( + residuals, masks, estimated_noise_std_dev + ) + ] + ) + # Per-image reduced chi2 statistic + reduced_chi2_stat_per_image = np.array( + [ + np.sum((standardized_residual * mask) ** 2) / (np.sum(mask)) + for standardized_residual, mask in zip(standardized_residuals, masks) + ] + ) + + # Compute the degrees of freedom and the mean + degrees_of_freedom = np.sum(masks) + mean_standardized_residuals = np.sum(standardized_residuals) / degrees_of_freedom + # The degrees of freedom is reduced by 1 because we're removing the mean (see Cochran's theorem) + reduced_chi2_stat = np.sum( + ((standardized_residuals - mean_standardized_residuals) * masks) ** 2 + ) / (degrees_of_freedom - 1) + + # Compute the average and media values of the noise std deviation + mean_noise_std_dev = np.mean(estimated_noise_std_dev) + median_noise_std_dev = np.median(estimated_noise_std_dev) + + # Compute the average, median and std dev of the reduced chi2 statistic per image + mean_reduced_chi2_stat_per_image = np.mean(reduced_chi2_stat_per_image) + median_reduced_chi2_stat_per_image = np.median(reduced_chi2_stat_per_image) + std_reduced_chi2_stat_per_image = np.std(reduced_chi2_stat_per_image) + + # Print chi2 results + logger.info("Reduced chi2:\t\t\t %.5e" % (reduced_chi2_stat)) + + logger.info("Average chi2 per image:\t\t %.5e" % (mean_reduced_chi2_stat_per_image)) + logger.info( + "Median chi2 per image:\t\t %.5e" % (median_reduced_chi2_stat_per_image) + ) + logger.info("Std dev chi2 per image:\t\t %.5e" % (std_reduced_chi2_stat_per_image)) + + logger.info("Average noise std dev:\t\t %.5e" % (mean_noise_std_dev)) + logger.info("Median noise std dev:\t\t %.5e" % (median_noise_std_dev)) + + return ( + reduced_chi2_stat, + reduced_chi2_stat_per_image, + mean_reduced_chi2_stat_per_image, + median_reduced_chi2_stat_per_image, + std_reduced_chi2_stat_per_image, + mean_noise_std_dev, + estimated_noise_std_dev, + ) + + def compute_mono_metric( tf_semiparam_field, gt_tf_semiparam_field, @@ -216,8 +425,8 @@ def compute_mono_metric( lambda_obs = lambda_list[it] phase_N = simPSF_np.feasible_N(lambda_obs) - residuals = np.zeros((total_samples)) - gt_star_mean = np.zeros((total_samples)) + residuals = np.zeros(total_samples) + gt_star_mean = np.zeros(total_samples) # Total number of epochs n_epochs = int(np.ceil(total_samples / batch_size)) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 027b0213..c06819ef 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -36,15 +36,12 @@ def __init__(self, metrics_params, trained_model): self.metrics_params = metrics_params self.trained_model = trained_model - def evaluate_metrics_polychromatic_lowres(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + def evaluate_metrics_polychromatic_lowres( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> 
Dict[str, float]: """Evaluate RMSE metrics for low-resolution polychromatic PSF. - This method computes Root Mean Square Error (RMSE) metrics for a + This method computes Root Mean Square Error (RMSE) metrics for a low-resolution polychromatic Point Spread Function (PSF) model. Parameters @@ -62,14 +59,14 @@ def evaluate_metrics_polychromatic_lowres(self, - ``C_poly`` Tensor or None, optional The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing the RMSE, relative RMSE, and their + A dictionary containing the RMSE, relative RMSE, and their corresponding standard deviation values. - ``rmse`` : float @@ -113,17 +110,94 @@ def evaluate_metrics_polychromatic_lowres(self, "std_rel_rmse": std_rel_rmse, } + def evaluate_metrics_chi2( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: + """Evaluate reduced chi2 metric for low-resolution polychromatic PSF. - def evaluate_metrics_mono_rmse(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + This method computes reduced chi2 metric for a + low-resolution polychromatic Point Spread Function (PSF) model. + + Parameters + ---------- + psf_model : object + An instance of the PSF model selected for metrics evaluation. + simPSF : object + An instance of the PSFSimulator. + data : object + A DataConfigHandler object containing training and test datasets. + dataset : dict + Dictionary containing dataset details, including: + - ``SEDs`` (Spectral Energy Distributions) + - ``positions`` (Star positions) + - ``C_poly`` Tensor or None, optional + The Zernike coefficient matrix used in generating simulations of the PSF model. This + matrix defines the Zernike polynomials up to a given order used to simulate the PSF + field. It may be present in some datasets or only required for some classes. + If not present or required, the model will proceed without it. + + + Returns + ------- + dict + A dictionary containing the reduced chi2 statistic and the Average estimated + noise standard deviation used for the chi squared calculation. + + - ``reduced_chi2`` : float + Reduced chi squared value. + - ``mean_noise_std_dev`` : float + Average estimated noise standard deviation used for the chi squared calculation. 
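This wrapper delegates the actual computation to `compute_chi2_metric` above. Its core statistic reduces to the NumPy steps below, shown on synthetic stamps with a known noise level so that both the per-image and the global reduced chi2 come out close to one; all values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)
n_stars, h, w = 4, 8, 8
noise_std = 0.05

reference = rng.normal(scale=noise_std, size=(n_stars, h, w))  # stand-in for observed stars
preds = np.zeros_like(reference)                               # stand-in for model predictions
masks = np.ones_like(reference)                                # 1 = use pixel, 0 = ignore
std_est = np.full(n_stars, noise_std)                          # per-star noise estimates

# Masked residuals, then per-star standardisation: subtract the masked mean, divide by the noise std.
residuals = (reference - preds) * masks
standardized = np.array(
    [(res - res.sum() / m.sum()) / s for res, m, s in zip(residuals, masks, std_est)]
)

# Per-image reduced chi2.
chi2_per_image = np.array(
    [np.sum((sr * m) ** 2) / np.sum(m) for sr, m in zip(standardized, masks)]
)

# Global reduced chi2; one degree of freedom is spent on the global mean (Cochran's theorem).
dof = masks.sum()
global_mean = standardized.sum() / dof
reduced_chi2 = np.sum(((standardized - global_mean) * masks) ** 2) / (dof - 1)

print(chi2_per_image.round(3), round(reduced_chi2, 3))
```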
+ + """ + logger.info("Computing polychromatic metrics at low resolution.") + + # Check if testing predictions should be masked + mask = self.trained_model.training_hparams.loss == "mask_mse" + + # Compute metrics + ( + reduced_chi2_stat, + reduced_chi2_stat_per_image, + mean_reduced_chi2_stat_per_image, + median_reduced_chi2_stat_per_image, + std_reduced_chi2_stat_per_image, + mean_noise_std_dev, + estimated_noise_std_dev, + ) = wf_metrics.compute_chi2_metric( + tf_trained_psf_model=psf_model, + gt_tf_psf_model=psf_models.get_psf_model( + self.metrics_params.ground_truth_model.model_params, + self.metrics_params.metrics_hparams, + data, + dataset.get("C_poly", None), # Extract C_poly or default to None + ), + simPSF_np=simPSF, + tf_pos=dataset["positions"], + tf_SEDs=dataset["SEDs"], + n_bins_lda=self.trained_model.model_params.n_bins_lda, + n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, + batch_size=self.metrics_params.metrics_hparams.batch_size, + dataset_dict=dataset, + mask=mask, + ) + + return { + "reduced_chi2": reduced_chi2_stat, + "reduced_chi2_stat_per_image": reduced_chi2_stat_per_image, + "mean_reduced_chi2_stat_per_image": mean_reduced_chi2_stat_per_image, + "median_reduced_chi2_stat_per_image": median_reduced_chi2_stat_per_image, + "std_reduced_chi2_stat_per_image": std_reduced_chi2_stat_per_image, + "mean_noise_std_dev": mean_noise_std_dev, + "estimated_noise_std_dev": estimated_noise_std_dev, + } + + def evaluate_metrics_mono_rmse( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate RMSE metrics for Monochromatic PSF. - This method computes Root Mean Square Error (RMSE) metrics for a - monochromatic Point Spread Function (PSF) model across a range of + This method computes Root Mean Square Error (RMSE) metrics for a + monochromatic Point Spread Function (PSF) model across a range of wavelengths. Parameters @@ -140,13 +214,13 @@ def evaluate_metrics_mono_rmse(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing RMSE, relative RMSE, and their corresponding + A dictionary containing RMSE, relative RMSE, and their corresponding standard deviation values computed over a wavelength range. - ``rmse_lda`` : float @@ -187,17 +261,13 @@ def evaluate_metrics_mono_rmse(self, "std_rmse_lda": std_rmse_lda, "std_rel_rmse_lda": std_rel_rmse_lda, } - - - def evaluate_metrics_opd(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + + def evaluate_metrics_opd( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate Optical Path Difference (OPD) metrics. - - This method computes Root Mean Square Error (RMSE) and relative RMSE + + This method computes Root Mean Square Error (RMSE) and relative RMSE for Optical Path Differences (OPD), along with their standard deviations. Parameters @@ -214,13 +284,13 @@ def evaluate_metrics_opd(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. 
This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing RMSE, relative RMSE, and their corresponding + A dictionary containing RMSE, relative RMSE, and their corresponding standard deviation values for OPD metrics. - ``rmse_opd`` : float @@ -259,16 +329,12 @@ def evaluate_metrics_opd(self, "rel_rmse_std_opd": rel_rmse_std_opd, } - - def evaluate_metrics_shape(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + def evaluate_metrics_shape( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate PSF Shape Metrics. - Computes shape-related metrics for the PSF model, including RMSE, + Computes shape-related metrics for the PSF model, including RMSE, relative RMSE, and their standard deviations. Parameters @@ -286,7 +352,7 @@ def evaluate_metrics_shape(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns @@ -325,11 +391,10 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. - + The metrics to evaluate are determined by the configuration in `metrics_params` and `metric_evaluation_flags`. Metrics are computed for both the training and test datasets, and results are stored in a dictionary. 
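With the `weights_path` argument and the internal `load_weights` call removed, the caller is now responsible for passing `evaluate_model` a model whose weights are already loaded. A hedged sketch of that calling pattern, where the config objects, model, and paths are assumed to come from the surrounding metrics runner:

```python
import logging

from wf_psf.metrics.metrics_interface import evaluate_model

logger = logging.getLogger(__name__)


def run_evaluation(metrics_params, trained_model_params, data, psf_model, weights_path, metrics_output):
    """Load the weights first, then evaluate; evaluate_model no longer receives weights_path."""
    logger.info("Loading PSF model weights from %s", weights_path)
    psf_model.load_weights(weights_path)

    evaluate_model(metrics_params, trained_model_params, data, psf_model, metrics_output)
```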
@@ -344,8 +409,6 @@ def evaluate_model( DataHandler object containing training and test data psf_model: object PSF model object - weights_path: str - Directory location of model weights metrics_output: str Directory location of metrics output @@ -356,8 +419,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) @@ -366,14 +429,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info("Loading PSF model weights from {}".format(weights_path)) - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} @@ -386,6 +441,10 @@ def evaluate_model( "test": True, "train": True, }, + "chi2_metric": { + "test": metrics_params.eval_chi2_metric, + "train": metrics_params.eval_chi2_metric, + }, "mono_metric": { "test": metrics_params.eval_mono_metric, "train": metrics_params.eval_mono_metric, @@ -403,6 +462,7 @@ def evaluate_model( # Define the metric evaluation functions metric_functions = { "poly_metric": metrics_handler.evaluate_metrics_polychromatic_lowres, + "chi2_metric": metrics_handler.evaluate_metrics_chi2, "mono_metric": metrics_handler.evaluate_metrics_mono_rmse, "opd_metric": metrics_handler.evaluate_metrics_opd, "shape_results_dict": metrics_handler.evaluate_metrics_shape, diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py index 3ab6ab78..613d488e 100644 --- a/src/wf_psf/plotting/plots_interface.py +++ b/src/wf_psf/plotting/plots_interface.py @@ -369,6 +369,7 @@ class ShapeMetricsPlotHandler: """ShapeMetricsPlotHandler class. A class to handle plot parameters shape metrics results. 
+ Parameters ---------- id: str @@ -526,6 +527,7 @@ def get_number_of_stars(metrics): ---------- metrics: dict A dictionary containig the metrics results per run + Returns ------- list_of_stars: list diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py similarity index 98% rename from src/wf_psf/psf_models/psf_model_parametric.py rename to src/wf_psf/psf_models/models/psf_model_parametric.py index 94f79f40..95a11615 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -9,7 +9,7 @@ import tensorflow as tf from wf_psf.psf_models.psf_models import register_psfclass -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -166,7 +166,6 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): ``simPSF_np = wf_psf.sims.psf_simulator.PSFSimulator(...)`` ``phase_N = simPSF_np.feasible_N(lambda_obs)`` """ - # Initialise the monochromatic PSF batch calculator tf_batch_mono_psf = TFBatchMonochromaticPSF( obscurations=self.obscurations, @@ -205,7 +204,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py similarity index 67% rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py rename to src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index ad61c1b5..fb8bc902 100644 --- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,11 +10,14 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter +from wf_psf.data.data_handler import get_data_array +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes, +) from wf_psf.psf_models import psf_models as psfm -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler -from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -22,6 +25,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -97,8 +101,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): A Recursive Namespace object containing parameters for this PSF model class. training_params: Recursive Namespace A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler - A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients. + data: DataConfigHandler or dict + A DataConfigHandler object or dict that provides access to single or multiple datasets (e.g. 
train and test), as well as prior knowledge like Zernike coefficients. coeff_mat: Tensor or None, optional Coefficient matrix defining the parametric PSF field model. @@ -108,204 +112,192 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): Initialized instance of the TFPhysicalPolychromaticField class. """ super().__init__(model_params, training_params, coeff_mat) - self._initialize_parameters_and_layers( - model_params, training_params, data, coeff_mat - ) - - def _initialize_parameters_and_layers( - self, - model_params: RecursiveNamespace, - training_params: RecursiveNamespace, - data: DataConfigHandler, - coeff_mat: Optional[tf.Tensor] = None, - ): - """Initialize Parameters of the PSF model. - - This method sets up the PSF model parameters, observational positions, - Zernike coefficients, and components required for the automatically - differentiable optical forward model. + self.model_params = model_params + self.training_params = training_params + self.data = data + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - Notes - ----- - - Initializes Zernike parameters based on dataset priors. - - Configures the PSF model layers according to `model_params`. - - If `coeff_mat` is provided, the model coefficients are updated accordingly. - """ + # Initialize the model parameters self.output_Q = model_params.output_Q - self.obs_pos = get_obs_positions(data) self.l2_param = model_params.param_hparams.l2_param - # Inputs: Save optimiser history Parametric model features - self.save_optim_history_param = ( - model_params.param_hparams.save_optim_history_param - ) - # Inputs: Save optimiser history NonParameteric model features - self.save_optim_history_nonparam = ( - model_params.nonparam_hparams.save_optim_history_nonparam - ) - self._initialize_zernike_parameters(model_params, data) - self._initialize_layers(model_params, training_params) + self.output_dim = model_params.output_dim - # Initialize the model parameters with non-default value + # Initialise lazy loading of external Zernike prior + self._external_prior = None + + # Set Zernike Polynomial Coefficient Matrix if not None if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) - def _initialize_zernike_parameters(self, model_params, data): - """Initialize the Zernike parameters. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. 
- """ - self.zks_prior = get_zernike_prior(model_params, data, data.batch_size) - self.n_zks_total = max( - model_params.param_hparams.n_zernikes, - tf.cast(tf.shape(self.zks_prior)[1], tf.int32), - ) - self.zernike_maps = psfm.generate_zernike_maps_3d( - self.n_zks_total, model_params.pupil_diameter + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 ) - def _initialize_layers(self, model_params, training_params): - """Initialize the layers of the PSF model. - - This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - """ - self._initialize_physical_layer(model_params) - self._initialize_polynomial_Z_field(model_params) - self._initialize_Zernike_OPD(model_params) - self._initialize_batch_polychromatic_layer(model_params, training_params) - self._initialize_nonparametric_opd_layer(model_params, training_params) - - def _initialize_physical_layer(self, model_params): - """Initialize the physical layer of the PSF model. - - This method initializes the physical layer of the PSF model using parameters - specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - """ - self.tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_prior, - interpolation_type=model_params.interpolation_type, - interpolation_args=model_params.interpolation_args, + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1], ) - def _initialize_polynomial_Z_field(self, model_params): - """Initialize the polynomial Zernike field of the PSF model. - - This method initializes the polynomial Zernike field of the PSF model using - parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter ) - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. 
+ # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + # Eagerly initialise model layers + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + _ = self.tf_poly_Z_field + _ = self.tf_np_poly_opd - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. + def _get_run_type(self, data): + if hasattr(data, "run_type"): + run_type = data.run_type + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "metrics", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + + def _assemble_zernike_contributions(self): + zks_inputs = ZernikeInputsFactory.build( + data=self.data, + run_type=self.run_type, + model_params=self.model_params, + prior=self._external_prior if hasattr(self, "_external_prior") else None, + ) + return assemble_zernike_contributions( + model_params=self.model_params, + zernike_prior=zks_inputs.zernike_prior, + centroid_dataset=zks_inputs.centroid_dataset, + positions=zks_inputs.misalignment_positions, + batch_size=self.training_params.batch_size, + ) - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. + @property + def save_param_history(self) -> bool: + """Check if the model should save the optimization history for parametric features.""" + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. + @property + def save_nonparam_history(self) -> bool: + """Check if the model should save the optimization history for non-parametric features.""" + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) + def get_obs_pos(self): + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" - """ - self.batch_size = training_params.batch_size - self.obscurations = psfm.tf_obscurations( - pupil_diam=model_params.pupil_diameter, - N_filter=model_params.LP_filter_length, - rotation_angle=model_params.obscuration_rotation_angle, + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" ) - self.output_dim = model_params.output_dim - self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. 
+ @property + def zks_total_contribution(self): + return self._zks_total_contribution + + @property + def n_zks_total(self): + """Get the total number of Zernike coefficients.""" + return self._n_zks_total + + @property + def zernike_maps(self): + """Get Zernike maps.""" + return self._zernike_maps + + @property + def opd_dim(self): + return self._opd_dim + + @property + def obscurations(self): + return self._obscurations + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field + + @tf_poly_Z_field.deleter + def tf_poly_Z_field(self): + del self._tf_poly_Z_field + + @property + def tf_physical_layer(self): + """Lazy loading of the physical layer of the PSF model.""" + if not hasattr(self, "_tf_physical_layer"): + self._tf_physical_layer = TFPhysicalLayer( + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) + return self._tf_physical_layer + + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + return TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - def _initialize_nonparametric_opd_layer(self, model_params, training_params): - """Initialize the non-parametric OPD layer. - - This method initializes the non-parametric OPD layer using the provided - `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. 
- - """ - # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() - - self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - d_max=model_params.nonparam_hparams.d_max_nonparam, - opd_dim=tf.shape(self.zernike_maps)[1].numpy(), - ) + @property + def tf_np_poly_opd(self): + """Lazy loading of the non-parametric polynomial variations OPD layer.""" + if not hasattr(self, "_tf_np_poly_opd"): + self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) + return self._tf_np_poly_opd def get_coeff_matrix(self): """Get coefficient matrix.""" @@ -335,18 +327,15 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. 
@@ -358,6 +347,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -471,12 +461,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -519,10 +513,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -547,10 +544,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -585,9 +585,10 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -622,8 +623,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction @@ -688,22 +689,21 @@ def call(self, inputs, training=True): # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - # Propagate to obtain the OPD + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD - self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) - ) + # Add L2 regularization loss on parametric OPD maps + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Combine both contributions + opd_maps = 
tf.add(param_opd_maps, nonparam_opd_maps) # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_semiparametric.py rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py index dc535204..7b2ff04d 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -10,9 +10,9 @@ import numpy as np import tensorflow as tf from wf_psf.psf_models import psf_models as psfm -from wf_psf.psf_models import tf_layers as tfl +from wf_psf.psf_models.tf_modules import tf_layers as tfl from wf_psf.utils.utils import decompose_tf_obscured_opd_basis -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, ) @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..e41e3536 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,57 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" + +import logging +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath + +logger = logging.getLogger(__name__) + + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. + + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. 
+ """ + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + status = model.load_weights(weights_path) + status.expect_partial() + + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py index dcb6abd9..41976590 100644 --- a/src/wf_psf/psf_models/psf_models.py +++ b/src/wf_psf/psf_models/psf_models.py @@ -186,24 +186,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None): def get_psf_model_weights_filepath(weights_filepath): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Parameters ---------- weights_filepath: str - Basename of the psf model weights to be loaded. + Basename of the PSF model weights to be loaded. Returns ------- str - The absolute path concatenated to the basename of the psf model weights to be loaded. + The absolute path concatenated to the basename of the PSF model weights to be loaded. """ try: return glob.glob(weights_filepath)[0].split(".")[0] except IndexError: logger.exception( - "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file." + "PSF weights file not found. Check that you've specified the correct weights file in the your config file." ) raise PSFModelError("PSF model weights error.") diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py similarity index 98% rename from src/wf_psf/psf_models/tf_layers.py rename to src/wf_psf/psf_models/tf_modules/tf_layers.py index 4479e73f..30de4804 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -1,6 +1,7 @@ import tensorflow as tf import tensorflow_addons as tfa -from wf_psf.psf_models.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils import logging @@ -188,7 +189,6 @@ def calculate_monochromatic_PSF(self, packed_elems): def calculate_polychromatic_PSF(self, packed_elems): """Calculate a polychromatic PSF.""" - self.current_opd = packed_elems[0][tf.newaxis, :, :] SED_pack_data = packed_elems[1] @@ -213,7 +213,6 @@ def _calculate_polychromatic_PSF(elems_to_unpack): def call(self, inputs): """Calculate the batch polychromatic PSFs.""" - # Unpack Inputs opd_batch = inputs[0] packed_SED_data = inputs[1] @@ -298,7 +297,6 @@ def set_output_params(self, output_Q, output_dim): def call(self, opd_batch): """Calculate the batch poly PSFs.""" - if self.phase_N is None: self.set_lambda_phaseN() @@ -311,7 +309,7 @@ def _calculate_PSF_batch(elems_to_unpack): swap_memory=True, ) - mono_psf_batch = _calculate_PSF_batch((opd_batch)) + mono_psf_batch = _calculate_PSF_batch(opd_batch) return mono_psf_batch @@ -319,7 +317,6 @@ def 
_calculate_PSF_batch(elems_to_unpack): class TFNonParametricPolynomialVariationsOPD(tf.keras.layers.Layer): """Non-parametric OPD generation with polynomial variations. - Parameters ---------- x_lims: [int, int] @@ -423,7 +420,6 @@ def call(self, positions): class TFNonParametricMCCDOPDv2(tf.keras.layers.Layer): """Non-parametric OPD generation with hybrid-MCCD variations. - Parameters ---------- obs_pos: tensor(n_stars, 2) @@ -641,7 +637,6 @@ def calc_index(idx_pos): class TFNonParametricGraphOPD(tf.keras.layers.Layer): """Non-parametric OPD generation with only graph-cosntraint variations. - Parameters ---------- obs_pos: tensor(n_stars, 2) @@ -749,7 +744,6 @@ def set_alpha_identity(self): def predict(self, positions): """Prediction step.""" - ## Graph part A_graph_train = tf.linalg.matmul(self.graph_dic, self.alpha_graph) # RBF interpolation @@ -959,13 +953,10 @@ def call(self, positions): If the shape of the input `positions` tensor is not compatible. """ + # Find indices for all positions in one batch operation + idx = find_position_indices(self.obs_pos, positions) - def calc_index(idx_pos): - return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] - - # Calculate the indices of the input batch - indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) - # Recover the prior zernikes from the batch indexes - batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) + # Gather the corresponding Zernike coefficients + batch_zks = tf.gather(self.zks_prior, idx, axis=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py similarity index 93% rename from src/wf_psf/psf_models/tf_modules.py rename to src/wf_psf/psf_models/tf_modules/tf_modules.py index 5a597847..2d93834e 100644 --- a/src/wf_psf/psf_models/tf_modules.py +++ b/src/wf_psf/psf_models/tf_modules/tf_modules.py @@ -1,10 +1,11 @@ """TensorFlow-Based PSF Modeling. -A module containing TensorFlow implementations for modeling monochromatic PSFs using Zernike polynomials and Fourier optics. +A module containing TensorFlow implementations for modeling monochromatic PSFs using Zernike polynomials and Fourier optics. :Author: Tobias Liaudat """ + import numpy as np import tensorflow as tf from typing import Optional @@ -21,7 +22,9 @@ class TFFftDiffract(tf.Module): Downsampling factor. Must be integer. """ - def __init__(self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] = None) -> None: + def __init__( + self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] = None + ) -> None: """Initialize the TFFftDiffract class. Parameters @@ -47,15 +50,15 @@ def __init__(self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] def tf_crop_img(self, image, output_crop_dim): """Crop images using TensorFlow methods. - This method handles a batch of 2D images and crops them to the specified dimension. - The images are expected to have the shape [batch, width, height], and the method + This method handles a batch of 2D images and crops them to the specified dimension. + The images are expected to have the shape [batch, width, height], and the method uses TensorFlow's `crop_to_bounding_box` to crop each image in the batch. Parameters ---------- image : tf.Tensor A batch of 2D images with shape [batch, height, width]. The images are expected - to be 3D tensors where the second and third dimensions represent the height and width. 
+ to be 3D tensors where the second and third dimensions represent the height and width. output_crop_dim : int The dimension of the square crop. The image will be cropped to this dimension. @@ -108,8 +111,8 @@ def normalize_psf(self, psf): def __call__(self, input_phase): """Calculate the normalized Point Spread Function (PSF) from a phase array. - This method takes a 2D input phase array, applies a 2D FFT-based diffraction operation, - crops the resulting PSF, and downscales it by a factor of Q if necessary. Finally, the PSF + This method takes a 2D input phase array, applies a 2D FFT-based diffraction operation, + crops the resulting PSF, and downscales it by a factor of Q if necessary. Finally, the PSF is normalized by summing over its spatial dimensions. Parameters @@ -120,7 +123,7 @@ def __call__(self, input_phase): Returns ------- tf.Tensor - The normalized PSF tensor with shape [batch, height, width], where each PSF is normalized + The normalized PSF tensor with shape [batch, height, width], where each PSF is normalized by its sum over the spatial dimensions. """ # Perform the FFT-based diffraction operation @@ -174,7 +177,13 @@ class TFBuildPhase(tf.Module): A tensor representing the obscurations (e.g., apertures or masks) to be applied to the phase. """ - def __init__(self, phase_N: int, lambda_obs: float, obscurations: tf.Tensor, name: Optional[str] = None) -> None: + def __init__( + self, + phase_N: int, + lambda_obs: float, + obscurations: tf.Tensor, + name: Optional[str] = None, + ) -> None: """Initialize the TFBuildPhase class. Parameters @@ -225,7 +234,6 @@ def zero_padding_diffraction(self, no_pad_phase): padded_phase = tf.pad(no_pad_phase, padding) return padded_phase - def apply_obscurations(self, phase: tf.Tensor) -> tf.Tensor: """Apply obscurations to the phase map. @@ -295,15 +303,15 @@ def __call__(self, opd): class TFZernikeOPD(tf.Module): """Convert Zernike coefficients into an Optical Path Difference (OPD). - This class performs the weighted sum of Zernike coefficients and Zernike maps - to compute the OPD. The Zernike maps and the corresponding Zernike coefficients + This class performs the weighted sum of Zernike coefficients and Zernike maps + to compute the OPD. The Zernike maps and the corresponding Zernike coefficients are required to perform the calculation. Parameters ---------- zernike_maps : tf.Tensor - A tensor containing the Zernike maps. The shape should be - (num_coeffs, x_dim, y_dim), where `num_coeffs` is the number of Zernike coefficients + A tensor containing the Zernike maps. The shape should be + (num_coeffs, x_dim, y_dim), where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are the dimensions of each map. name : str, optional @@ -312,12 +320,12 @@ class TFZernikeOPD(tf.Module): Returns ------- tf.Tensor - A tensor representing the OPD, with shape (num_star, x_dim, y_dim), - where `num_star` corresponds to the number of stars and `x_dim`, `y_dim` are + A tensor representing the OPD, with shape (num_star, x_dim, y_dim), + where `num_star` corresponds to the number of stars and `x_dim`, `y_dim` are the dimensions of the OPD map. """ - def __init__(self, zernike_maps : tf.Tensor, name: Optional[str] = None) -> None: + def __init__(self, zernike_maps: tf.Tensor, name: Optional[str] = None) -> None: """ Initialize the TFZernikeOPD class. 
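The `TFBuildPhase` and `TFFftDiffract` hunks above describe the usual Fourier-optics pipeline: build a complex pupil from the phase and the obscurations, diffract with an FFT, crop, and normalize. The toy function below sketches that computation in isolation; it is illustrative only and is not the modules' actual implementation (padding, sampling and batching conventions are simplified).

```python
import tensorflow as tf


def toy_fft_psf(phase, obscurations, output_dim=64):
    """Toy monochromatic PSF from a single [N, N] float32 phase map and pupil mask."""
    # Complex pupil: obscuration mask times exp(i * phase)
    pupil = tf.cast(obscurations, tf.complex64) * tf.exp(
        tf.complex(tf.zeros_like(phase), phase)
    )
    # Far-field diffraction: centered 2D FFT, intensity = |field|^2
    field = tf.signal.fftshift(tf.signal.fft2d(tf.signal.ifftshift(pupil)))
    psf = tf.math.real(field * tf.math.conj(field))
    # Crop the central output_dim x output_dim window and normalize to unit flux
    center = int(psf.shape[0]) // 2
    half = output_dim // 2
    psf = psf[center - half : center + half, center - half : center + half]
    return psf / tf.reduce_sum(psf)
```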
@@ -332,18 +340,18 @@ def __init__(self, zernike_maps : tf.Tensor, name: Optional[str] = None) -> None self.zernike_maps = zernike_maps - def __call__(self, z_coeffs : tf.Tensor) -> tf.Tensor: + def __call__(self, z_coeffs: tf.Tensor) -> tf.Tensor: """Compute the OPD from Zernike coefficients and maps. - This method calculates the OPD by performing the weighted sum of Zernike - coefficients and corresponding Zernike maps. The result is a tensor representing + This method calculates the OPD by performing the weighted sum of Zernike + coefficients and corresponding Zernike maps. The result is a tensor representing the computed OPD for the given coefficients. Parameters ---------- z_coeffs : tf.Tensor - A tensor containing the Zernike coefficients. The shape should be - (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and + A tensor containing the Zernike coefficients. The shape should be + (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and `num_coeffs` is the number of Zernike coefficients. Returns @@ -358,30 +366,30 @@ def __call__(self, z_coeffs : tf.Tensor) -> tf.Tensor: class TFZernikeMonochromaticPSF(tf.Module): """Build a monochromatic Point Spread Function (PSF) from Zernike coefficients. - This class computes the monochromatic PSF by following the Zernike model. It - involves multiple stages, including the calculation of the OPD (Optical Path - Difference), the phase from the OPD, and diffraction via FFT-based operations. + This class computes the monochromatic PSF by following the Zernike model. It + involves multiple stages, including the calculation of the OPD (Optical Path + Difference), the phase from the OPD, and diffraction via FFT-based operations. The Zernike coefficients are used to generate the PSF. Parameters ---------- phase_N : int The size of the phase grid, typically a square matrix dimension. - + lambda_obs : float The wavelength of the observed light. - + obscurations : tf.Tensor - A tensor representing the obscurations in the system, which will be applied + A tensor representing the obscurations in the system, which will be applied to the phase. zernike_maps : tf.Tensor - A tensor containing the Zernike maps, with the shape (num_coeffs, x_dim, y_dim), - where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are + A tensor containing the Zernike maps, with the shape (num_coeffs, x_dim, y_dim), + where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are the dimensions of the Zernike maps. output_dim : int, optional, default=64 - The output dimension of the PSF, i.e., the size of the resulting image. + The output dimension of the PSF, i.e., the size of the resulting image. name : str, optional The name of the module. Default is `None`. @@ -390,17 +398,22 @@ class TFZernikeMonochromaticPSF(tf.Module): ---------- tf_build_opd_zernike : TFZernikeOPD A module used to generate the OPD from the Zernike coefficients. - + tf_build_phase : TFBuildPhase A module used to compute the phase from the OPD. - + tf_fft_diffract : TFFftDiffract A module that performs the diffraction calculation using FFT-based methods. """ def __init__( - self, phase_N: int, lambda_obs: float, obscurations: tf.Tensor, - zernike_maps: tf.Tensor, output_dim: int = 64, name: Optional[str] = None + self, + phase_N: int, + lambda_obs: float, + obscurations: tf.Tensor, + zernike_maps: tf.Tensor, + output_dim: int = 64, + name: Optional[str] = None, ): """ Initialize the TFZernikeMonochromaticPSF class. 
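For reference, the weighted sum documented in `TFZernikeOPD.__call__` above reduces to a broadcast-and-sum over the Zernike axis. The snippet below is a toy illustration with made-up shapes, not the layer's exact code.

```python
import tensorflow as tf


def toy_zernike_opd(z_coeffs, zernike_maps):
    """z_coeffs: (num_star, num_coeffs, 1, 1); zernike_maps: (num_coeffs, x_dim, y_dim)."""
    # Broadcast to (num_star, num_coeffs, x_dim, y_dim), then sum over the Zernike axis
    return tf.reduce_sum(z_coeffs * zernike_maps[tf.newaxis, ...], axis=1)


# Made-up shapes: 3 stars, 15 Zernike modes, 256x256 pupil grid
z = tf.random.normal([3, 15, 1, 1])
maps = tf.random.normal([15, 256, 256])
opd = toy_zernike_opd(z, maps)  # -> shape (3, 256, 256)
```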
@@ -437,15 +450,15 @@ def __call__(self, z_coeffs): Parameters ---------- z_coeffs : tf.Tensor - A tensor containing the Zernike coefficients. The shape should be - (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars + A tensor containing the Zernike coefficients. The shape should be + (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and `num_coeffs` is the number of Zernike coefficients. Returns ------- tf.Tensor - A tensor representing the computed PSF, with shape - (num_star, output_dim, output_dim), where `output_dim` is the size of + A tensor representing the computed PSF, with shape + (num_star, output_dim, output_dim), where `output_dim` is the size of the resulting PSF image. """ # Generate OPD from Zernike coefficients diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py similarity index 97% rename from src/wf_psf/psf_models/tf_psf_field.py rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 9ef58ac6..86575298 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -9,15 +9,16 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFZernikeOPD, TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, TFPhysicalLayer, ) -from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor( + get_data_array(data, data.run_type, key="positions"), dtype=tf.float32 + ) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..49f3471f --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,102 @@ +"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf + + +@tf.function +def find_position_indices(obs_pos, batch_positions): + """Find indices of batch positions within observed positions using vectorized operations. + + This function locates the indices of multiple query positions within a + reference set of observed positions using broadcasting and vectorized operations. + Each position in the batch must have an exact match in the observed positions. 
+ + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos", + ) + + return indices + + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). + + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) diff --git a/src/wf_psf/psf_models/zernikes.py b/src/wf_psf/psf_models/zernikes.py deleted file mode 100644 index dcfa6e39..00000000 --- a/src/wf_psf/psf_models/zernikes.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Zernikes. - -A module to make Zernike maps. - -:Author: Tobias Liaudat and Jennifer Pollack - -""" - -import numpy as np -import zernike as zk -import logging - -logger = logging.getLogger(__name__) - - -def zernike_generator(n_zernikes, wfe_dim): - """ - Generate Zernike maps. - - Based on the zernike github repository. - https://github.com/jacopoantonello/zernike - - Parameters - ---------- - n_zernikes: int - Number of Zernike modes desired. - wfe_dim: int - Dimension of the Zernike map [wfe_dim x wfe_dim]. - - Returns - ------- - zernikes: list of np.ndarray - List containing the Zernike modes. - The values outside the unit circle are filled with NaNs.
- """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes diff --git a/src/wf_psf/run.py b/src/wf_psf/run.py index 828a6e17..171519bb 100644 --- a/src/wf_psf/run.py +++ b/src/wf_psf/run.py @@ -93,9 +93,7 @@ def mainMethod(): except Exception as e: logger.error( - "Check your config file {} for errors. Error Msg: {}.".format( - args.conffile, e - ), + f"Check your config file {args.conffile} for errors. Error Msg: {e}.", exc_info=True, ) diff --git a/src/wf_psf/sims/psf_simulator.py b/src/wf_psf/sims/psf_simulator.py index f953cded..dec06ff3 100644 --- a/src/wf_psf/sims/psf_simulator.py +++ b/src/wf_psf/sims/psf_simulator.py @@ -19,7 +19,7 @@ print("Problem importing skimage..") -class PSFSimulator(object): +class PSFSimulator: """Simulate PSFs. In the future the zernike maps could be created with galsim or some other @@ -198,7 +198,7 @@ def fft_diffract(wf, output_Q, output_dim=64): start = int(psf.shape[0] // 2 - (output_dim * output_Q) // 2) stop = int(psf.shape[0] // 2 + (output_dim * output_Q) // 2) else: - start = int(0) + start = 0 stop = psf.shape[0] # Crop psf @@ -227,7 +227,7 @@ def generate_euclid_pupil_obscurations(N_pix=1024, N_filter=3, rotation_angle=0) """Generate Euclid like pupil obscurations. This method simulates the 2D pupil obscurations for the Euclid telescope, - considering the aperture stop, mirror obscurations, and spider arms. It does + considering the aperture stop, mirror obscurations, and spider arms. It does not account for any 3D projections or the angle of the Field of View (FoV). Parameters @@ -366,7 +366,6 @@ def decimate_im(input_im, decim_f): Based on the PIL library using the default interpolator. """ - pil_im = PIL.Image.fromarray(input_im) (width, height) = (pil_im.width // decim_f, pil_im.height // decim_f) im_resized = pil_im.resize((width, height)) @@ -593,7 +592,6 @@ def calculate_wfe_rms(self, z_coeffs=None): def check_wfe_rms(self, z_coeffs=None, max_wfe_rms=None): """Check if Zernike coefficients are within the maximum admitted error.""" - if max_wfe_rms is None: max_wfe_rms = self.max_wfe_rms diff --git a/src/wf_psf/sims/spatial_varying_psf.py b/src/wf_psf/sims/spatial_varying_psf.py index 33548d7f..f343b025 100644 --- a/src/wf_psf/sims/spatial_varying_psf.py +++ b/src/wf_psf/sims/spatial_varying_psf.py @@ -240,7 +240,6 @@ def check_position_coordinate_limits(xv, yv, x_lims, y_lims, verbose): None """ - x_check = np.sum(xv >= x_lims[1] * 1.1) + np.sum(xv <= x_lims[0] * 1.1) y_check = np.sum(yv >= y_lims[1] * 1.1) + np.sum(yv <= y_lims[0] * 1.1) @@ -480,7 +479,7 @@ def calculate_zernike( ) -class SpatialVaryingPSF(object): +class SpatialVaryingPSF: """Spatial Varying PSF. Generate PSF field with polynomial variations of Zernike coefficients. @@ -621,7 +620,6 @@ def calculate_wfe_rms(self, xv, yv, polynomial_coeffs): numpy.ndarray An array containing the WFE RMS values for the provided positions. 
""" - Z = ZernikeHelper.generate_zernike_polynomials( xv, yv, self.x_lims, self.y_lims, self.d_max, polynomial_coeffs ) @@ -645,7 +643,6 @@ def build_polynomial_coeffs(self): ------- None """ - # Build mesh xv_grid, yv_grid = MeshHelper.build_mesh( self.x_lims, self.y_lims, self.grid_points diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", "\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - "print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = 
mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - 
"mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " 
control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 52% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 02e479e9..e22ab042 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,109 +8,94 @@ import numpy as np import pytest +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator +from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import ( - compute_zernike_tip_tilt, - CentroidEstimator -) + # Function to compute centroid based on first-order moments def calculate_centroid(image, mask=None): if mask is not None: image = np.ma.masked_array(image, mask=mask) - + # Calculate moments M00 = np.sum(image) M10 = np.sum(np.arange(image.shape[1]) * np.sum(image, axis=0)) M01 = np.sum(np.arange(image.shape[0]) * np.sum(image, axis=1)) - + # Centroid formula xc = M10 / M00 yc = M01 / M00 return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - return image - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images @pytest.fixture def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array (5x5 image for each "image") - + image = np.zeros( + (num_images, 5, 5) + ) # Create a 3D array (5x5 image for each "image") + # Place non-zero values in multiple pixels image[:, 2, 2] = 10 # Star at the center - image[:, 2, 
3] = 10 # Adjacent pixel + image[:, 2, 3] = 10 # Adjacent pixel image[:, 3, 2] = 10 # Adjacent pixel image[:, 3, 3] = 10 # Adjacent pixel forming a symmetric pattern mask = np.zeros_like(image) - mask[:, 3, 2] = 1 - mask[:, 3, 3] = 1 + mask[:, 3, 2] = 1 + mask[:, 3, 3] = 1 return image, mask -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" - num_images = simple_image.shape[0] # Get the number of images from the first dimension + num_images = simple_image.shape[ + 0 + ] # Get the number of images from the first dimension mask = np.ones((num_images, 5, 5)) # Create a batch of masks mask[:, 1:4, 1:4] = 0 # Mask a 3x3 region for each image return simple_image, mask + @pytest.fixture def centroid_estimator(simple_image): """Fixture for initializing CentroidEstimator.""" return CentroidEstimator(simple_image) + @pytest.fixture def centroid_estimator_with_mask(simple_image_with_mask): """Fixture for initializing CentroidEstimator with a mask.""" image, mask = simple_image_with_mask return CentroidEstimator(image, mask=mask) + @pytest.fixture def simple_image_with_centroid(simple_image): """Fixture for a simple image with known centroid and initial position.""" image = simple_image - + # Known centroid and initial position (xc0, yc0) - for testing xc0, yc0 = 2.0, 2.0 # Assume the initial center of the image is (2.0, 2.0) - + # Create CentroidEstimator instance - centroid_estimator = CentroidEstimator(im=image, n_iter=1) + centroid_estimator = CentroidEstimator(im=image, n_iter=1) - centroid_estimator.window=np.ones_like(image) + centroid_estimator.window = np.ones_like(image) centroid_estimator.xc0 = xc0 centroid_estimator.yc0 = yc0 - + # Simulate the computed centroid being slightly off-center centroid_estimator.xc = 2.3 centroid_estimator.yc = 2.7 - + return centroid_estimator + @pytest.fixture def batch_images(): """Fixture for multiple PSF images.""" @@ -120,101 +105,84 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.utils.centroids.CentroidEstimator", autospec=True) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - zernike_corrections = 
compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - - # Expected shifts based on centroid calculation - expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters - expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - - # Check dy values similarly - np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 - np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" - - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.utils.centroids.CentroidEstimator", autospec=True) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } + + # Mock the internal function calls: + with ( + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, centroid_dataset) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } + + # Mock internal function calls + with ( + patch( + 
"wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call function under test + result = compute_centroid_correction(model_params, centroid_dataset) + + # Validate result shape + assert result.shape == (4, 3) # (n_stars, 3 Zernike components) + + # Validate expected values (adjust based on behavior) + expected_result = -1.0 * np.array( + [ + [0, 0.1, 0.2], # From training data + [0, 0.3, 0.4], + [0, 0.1, 0.2], # From test data (reused mocked return) + [0, 0.3, 0.4], + ] ) - - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" - - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] - - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) - - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) - - # Process the displacements and expected values for each image in the batch - expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image - np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image - + assert np.allclose(result, expected_result) # Test for centroid calculation without mask @@ -225,6 +193,7 @@ def test_centroid_calculation_one_star(centroid_estimator): assert np.isclose(xc, 2.0) assert np.isclose(yc, 2.0) + # Test for centroid calculation with mask def test_centroid_calculation_with_one_star_and_mask(centroid_estimator_with_mask): """Test that the centroid is calculated correctly when a mask is applied.""" @@ -233,53 +202,61 @@ def test_centroid_calculation_with_one_star_and_mask(centroid_estimator_with_mas assert np.isclose(xc, 2.0) assert np.isclose(yc, 2.0) + def test_centroid_calculation_multiple_images(multiple_images): """Test the centroid estimation for a batch of images.""" estimator = CentroidEstimator(im=multiple_images) - + # Check that centroids are correctly estimated expected_centroids = [(2.0, 2.0), (1.0, 3.0), (3.0, 1.0)] - for i, (xc, yc) in enumerate(zip(estimator.xc,estimator.yc)): + for i, (xc, yc) in enumerate(zip(estimator.xc, estimator.yc)): assert np.isclose(xc, expected_centroids[i][0]) assert np.isclose(yc, expected_centroids[i][1]) + def test_centroid_no_mask(simple_star_and_mask): # Extract star star, _ = simple_star_and_mask # Expected centroid for the symmetric pattern - true_centroid = (2.5, 2.5) - + true_centroid = (2.5, 2.5) + # Create the CentroidEstimator instance (assuming auto_run=True by default) centroid_estimator = CentroidEstimator(im=star, 
n_iter=1) - + # Check if the centroid is calculated correctly computed_centroid = (centroid_estimator.xc, centroid_estimator.yc) assert np.isclose(computed_centroid[0], true_centroid[0]) assert np.isclose(computed_centroid[1], true_centroid[1]) + # Test for centroid calculation with a mask def test_centroid_with_mask(simple_star_and_mask): # Extract star and mask star, mask = simple_star_and_mask # Expected centroid after masking (estimated manually) - expected_masked_centroid = (2.0, 2.5) - + expected_masked_centroid = (2.0, 2.5) + # Create the CentroidEstimator instance (with mask) centroid_estimator = CentroidEstimator(im=star, mask=mask, n_iter=1) - + # Check if the centroid is calculated correctly with the mask applied computed_centroid = (centroid_estimator.xc, centroid_estimator.yc) assert np.isclose(computed_centroid[0], expected_masked_centroid[0]) assert np.isclose(computed_centroid[1], expected_masked_centroid[1]) + def test_centroid_estimator_initialization(simple_image): """Test the initialization of the CentroidEstimator.""" estimator = CentroidEstimator(simple_image) assert estimator.im.shape == (1, 5, 5) # Shape should match the input image - assert estimator.xc0 == 2.5 # Default xc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 - assert estimator.yc0 == 2.5 # Default yc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 + assert ( + estimator.xc0 == 2.5 + ) # Default xc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 + assert ( + estimator.yc0 == 2.5 + ) # Default yc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 assert estimator.sigma_init == 7.5 # Default sigma_init should be 7.5 assert estimator.n_iter == 5 # Default number of iterations should be 5 assert estimator.mask is None # By default, mask should be None @@ -287,7 +264,6 @@ def test_centroid_estimator_initialization(simple_image): def test_single_iteration(centroid_estimator): """Test that the internal methods are called exactly once for n_iter=1.""" - # Mock the methods centroid_estimator.update_grid = MagicMock() centroid_estimator.elliptical_gaussian = MagicMock() @@ -295,7 +271,7 @@ def test_single_iteration(centroid_estimator): # Set n_iter to 1 centroid_estimator.n_iter = 1 - + # Run the estimate method centroid_estimator.estimate() @@ -304,14 +280,18 @@ def test_single_iteration(centroid_estimator): centroid_estimator.elliptical_gaussian.assert_called_once() centroid_estimator.compute_moments.assert_called_once() + def test_single_iteration_auto_run(simple_image): """Test that the internal methods are called exactly once for n_iter=1.""" - # Patch the methods at the time the object is created - with patch.object(CentroidEstimator, 'update_grid') as update_grid_mock, \ - patch.object(CentroidEstimator, 'elliptical_gaussian') as elliptical_gaussian_mock, \ - patch.object(CentroidEstimator, 'compute_moments') as compute_moments_mock: - + with ( + patch.object(CentroidEstimator, "update_grid") as update_grid_mock, + patch.object( + CentroidEstimator, "elliptical_gaussian" + ) as elliptical_gaussian_mock, + patch.object(CentroidEstimator, "compute_moments") as compute_moments_mock, + ): + # Initialize the CentroidEstimator with auto_run=True centroid_estimator = CentroidEstimator(im=simple_image, n_iter=1, auto_run=True) @@ -320,54 +300,79 @@ def test_single_iteration_auto_run(simple_image): elliptical_gaussian_mock.assert_called_once() compute_moments_mock.assert_called_once() + def test_update_grid(simple_image): """Test 
that the grid is correctly updated.""" centroid_estimator = CentroidEstimator(im=simple_image, auto_run=True, n_iter=1) - + # Check the shapes of the grid coordinates assert centroid_estimator.xx.shape == (1, 5, 5) assert centroid_estimator.yy.shape == (1, 5, 5) - + # Check the values of the grid coordinates # xx should be the same for all rows and columns (broadcasted across the image) - assert np.allclose(centroid_estimator.xx, - np.array([[[[-2.5, -2.5, -2.5, -2.5, -2.5], - [-1.5, -1.5, -1.5, -1.5, -1.5], - [-0.5, -0.5, -0.5, -0.5, -0.5], - [ 0.5, 0.5, 0.5, 0.5, 0.5], - [ 1.5, 1.5, 1.5, 1.5, 1.5]]]])) - + assert np.allclose( + centroid_estimator.xx, + np.array( + [ + [ + [ + [-2.5, -2.5, -2.5, -2.5, -2.5], + [-1.5, -1.5, -1.5, -1.5, -1.5], + [-0.5, -0.5, -0.5, -0.5, -0.5], + [0.5, 0.5, 0.5, 0.5, 0.5], + [1.5, 1.5, 1.5, 1.5, 1.5], + ] + ] + ] + ), + ) + # yy should be the same for all columns and rows (broadcasted across the image) - assert np.allclose(centroid_estimator.yy, - np.array([[[[-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5]]]])) + assert np.allclose( + centroid_estimator.yy, + np.array( + [ + [ + [ + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + ] + ] + ] + ), + ) + def test_elliptical_gaussian(simple_image): """Test that the elliptical Gaussian is calculated correctly.""" centroid_estimator = CentroidEstimator(im=simple_image, n_iter=1) # Check if the output is a valid 2D array with the correct shape assert centroid_estimator.window.shape == (1, 5, 5) - + # Check if the Gaussian window values are reasonable (non-negative and decrease with distance) assert np.all(centroid_estimator.window >= 0) - assert np.isclose(np.sum(centroid_estimator.window), 25, atol=1.0) + assert np.isclose(np.sum(centroid_estimator.window), 25, atol=1.0) def test_intra_pixel_shifts(simple_image_with_centroid): """Test the return_intra_pixel_shifts method.""" - centroid_estimator = simple_image_with_centroid - + # Calculate intra-pixel shifts shifts = centroid_estimator.get_intra_pixel_shifts() - + # Expected intra-pixel shifts expected_x_shift = 2.3 - 2.0 # xc - xc0 expected_y_shift = 2.7 - 2.0 # yc - yc0 - + # Check that the shifts are correct - assert np.isclose(shifts[0], expected_x_shift), f"Expected {expected_x_shift}, got {shifts[0]}" - assert np.isclose(shifts[1], expected_y_shift), f"Expected {expected_y_shift}, got {shifts[1]}" + assert np.isclose( + shifts[0], expected_x_shift + ), f"Expected {expected_x_shift}, got {shifts[0]}" + assert np.isclose( + shifts[1], expected_y_shift + ), f"Expected {expected_y_shift}, got {shifts[1]}" diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 6159d53a..131922e5 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -9,8 +9,13 @@ """ import pytest +import numpy as np +import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -93,6 +98,76 @@ ) +@pytest.fixture +def mock_data(scope="module"): + """Fixture to provide mock data for testing.""" + # Mock positions and Zernike priors + training_positions = 
tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) + + # Define dummy 5x5 image patches for stars (mock star images) + # Define varied values for 5x5 star images + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) + + return MockData( + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, + ) + + +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + +@pytest.fixture +def simple_image(scope="module"): + """Fixture for a simple star image.""" + num_images = 1 # Change this to test with multiple images + image = np.zeros((num_images, 5, 5)) # Create a 3D array + image[:, 2, 2] = 1 # Place the star at the center for each image + return image + + +@pytest.fixture +def identity_mask(scope="module"): + """Creates a mask where all pixels are fully considered.""" + return np.ones((5, 5)) + + +@pytest.fixture +def multiple_images(scope="module"): + """Fixture for a batch of images with stars at different positions.""" + images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 + images[0, 2, 2] = 1 # Star at center of image 0 + images[1, 1, 3] = 1 # Star at (1, 3) in image 1 + images[2, 3, 1] = 1 # Star at (3, 1) in image 2 + return images + + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py new file mode 100644 index 00000000..85200b8a --- /dev/null +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -0,0 +1,365 @@ +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_handler import ( + DataHandler, + get_data_array, + extract_star_data, +) +from wf_psf.utils.read_config import RecursiveNamespace + + +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically + data_handler = DataHandler( + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True + ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda + + +def test_load_train_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "train_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "noisy_stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the 
temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal( + data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] + ) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_load_test_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "test_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + dataset_type="test", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "train_data.npy" + + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." + ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_load_test_dataset_missing_stars(tmp_path, simPSF): + """Test that a warning is raised if 'stars' is missing in test data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "test_data.npy" + + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "test", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." 
+ ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_extract_star_data_valid_keys(mock_data): + """Test extracting valid data from the dataset.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_masks(mock_data): + """Test extracting star masks from the dataset.""" + result = extract_star_data(mock_data, train_key="masks", test_key="masks") + + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_missing_key(mock_data): + """Test that the function raises a KeyError when a key is missing.""" + with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): + extract_star_data(mock_data, train_key="invalid_key", test_key="stars") + + +def test_extract_star_data_partially_missing_key(mock_data): + """Test that the function raises a KeyError if only one key is missing.""" + with pytest.raises( + KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" + ): + extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") + + +def test_extract_star_data_tensor_conversion(mock_data): + """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + assert isinstance(result, np.ndarray), "The result should be a NumPy array" + assert result.dtype == np.float32, "The NumPy array should have dtype float32" + + +def test_reference_shifts_broadcasting(): + reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts + shifts = np.random.rand(2, 2400) # Example shifts array + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to( + reference_shifts[:, None], shifts.shape + ) # Shape will be (2, 2400) + + # Ensure shapes are compatible for subtraction + displacements = reference_shifts - shifts + + # Test the result + assert displacements.shape == shifts.shape, "Shapes do not match" + assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # 
================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..66d23309 --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,543 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +import tensorflow as tf +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, + compute_zernike_tip_tilt, + pad_tf_zernikes, +) +from types import SimpleNamespace as RecursiveNamespace + + +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + + +def test_training_without_prior(mock_model_params, mock_data): + mock_model_params.use_prior = False + + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + 
] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + + assert zinputs.zernike_prior is None + + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): + mock_model_params.use_prior = True + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) + + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) + + assert "Explicit prior provided; ignoring dataset-based prior." in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + mock_model_params.use_prior = True + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) + + +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) + + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), + ) + + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + + result = get_np_zernike_prior(data) + + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, 
expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") 
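+# Descriptive note on the setup of the test below: both patched helpers are mocked
+# to return constant contributions (2.0 for the centroid correction, 3.0 for the
+# CCD misalignment), so only the summation logic of assemble_zernike_contributions
+# is exercised; combined with the all-ones dummy prior, every coefficient of the
+# assembled result should equal 6.0.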
+@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions, + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None, + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, + param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), dtype=tf.float32) + + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, zernike_prior=tensor_prior + ) + assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" + assert isinstance(result, tf.Tensor) + np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +def test_inconsistent_shapes_raises_error( + mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): + mock_model_params.add_ccd_misalignments = False + mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 + + with pytest.raises( + ValueError, match="All contributions must have the same number of samples" + ): + assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=None, + ) + + +def test_pad_zernikes_num_of_zernikes_equal(): + # Prepare your test tensors + zk_param = tf.constant([[[[1.0]]], [[[2.0]]]]) # Shape (2, 1, 1, 1) + zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]]) # Same shape + + # Reshape to (1, 2, 1, 1) + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) + + # Reset _n_zks_total to max number of zernikes (2 here) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call pad_zernikes method + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert shapes are equal and correct + assert padded_zk_param.shape[1] == n_zks_total + assert padded_zk_prior.shape[1] == n_zks_total + + # If num zernikes already equal, output should be unchanged + np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) + np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + + +def test_pad_zernikes_prior_greater_than_param(): + zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + 
zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 5, 1, 1) + assert padded_zk_prior.shape == (1, 5, 1, 1) + + +def test_pad_zernikes_param_greater_than_prior(): + zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) + zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 4, 1, 1) + assert padded_zk_prior.shape == (1, 4, 1, 1) + + +def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): + """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02]] + ) # Shape (1, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test + ) + + # Define test inputs (batch of 1 image) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + + # Expected shifts based on centroid calculation + expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters + expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters + + # Expected calls to the mocked function + # Extract the arguments passed to mock_shift_fn + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose( + args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Check dy values similarly + np.testing.assert_allclose( + args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose( + zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 + np.testing.assert_allclose( + 
zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 + + +def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): + """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] + ) # Shape (3, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test + ) + + # Define test inputs (batch of 3 images) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts, + ) + + # Check if the mock function was called once with the full batch + assert ( + len(mock_shift_fn.call_args_list) == 1 + ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + + # Get the arguments passed to the mock function for the batch of images + args, _ = mock_shift_fn.call_args_list[0] + + print("Shape of args[0]:", args[0].shape) + print("Contents of args[0]:", args[0]) + print("Mock function call args list:", mock_shift_fn.call_args_list) + + # Reshape args[0] to (N, 2) for batch processing + args_array = np.array(args[0]).reshape(-1, 2) + + # Process the displacements and expected values for each image in the batch + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] + ) # Expected y-axis shift in meters + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose( + args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + np.testing.assert_allclose( + args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose( + zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 for each image + np.testing.assert_allclose( + zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 for each image diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..de111427 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,36 @@ +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors=None, + test_zernike_priors=None, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks, + ) + self.test_data = MockDataset( + 
positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks, + ) diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/training_preprocessing_test.py deleted file mode 100644 index 6769d0e5..00000000 --- a/src/wf_psf/tests/test_data/training_preprocessing_test.py +++ /dev/null @@ -1,368 +0,0 @@ -import pytest -import numpy as np -import tensorflow as tf -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.data.training_preprocessing import ( - DataHandler, - get_obs_positions, - get_zernike_prior, - extract_star_data, - compute_centroid_correction, -) -import logging -from unittest.mock import patch - -class MockData: - def __init__( - self, - training_positions, - test_positions, - training_zernike_priors, - test_zernike_priors, - noisy_stars=None, - noisy_masks=None, - stars=None, - masks=None, - ): - self.training_data = MockDataset( - positions=training_positions, - zernike_priors=training_zernike_priors, - star_type="noisy_stars", - stars=noisy_stars, - masks=noisy_masks) - self.test_data = MockDataset( - positions=test_positions, - zernike_priors=test_zernike_priors, - star_type="stars", - stars=stars, - masks=masks) - - -class MockDataset: - def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} - - -@pytest.fixture -def mock_data(): - # Mock data for testing - # Mock training and test positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) - # Mock noisy stars, stars and masks - noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) - noisy_masks = tf.constant([[1], [0]], dtype=tf.float32) - stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32) - masks = tf.constant([[0], [1]], dtype=tf.float32) - - return MockData( - training_positions, test_positions, training_zernike_priors, test_zernike_priors, noisy_stars, noisy_masks, stars, masks - ) - - -def test_process_sed_data(data_params, simPSF): - # Test processing SED data without initialization - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=False - ) - assert data_handler.sed_data is None # SED data should not be processed - - # Test processing SED data with initialization - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=True - ) - assert data_handler.sed_data is not None # SED data should be processed - - -def test_load_train_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "train_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) - - # Call the load_dataset 
method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal( - data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] - ) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_test_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "test_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) - - # Call the load_dataset method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "train_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") - -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'stars' is missing in test data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "test_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - 
data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - assert isinstance(data_handler.sed_data, tf.Tensor) - assert data_handler.sed_data.dtype == tf.float32 - assert data_handler.sed_data.shape == ( - len(data_handler.dataset["positions"]), - n_bins_lambda, - len(["feasible_N", "feasible_wv", "SED_norm"]), - ) - - -def test_get_obs_positions(mock_data): - observed_positions = get_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - -def test_extract_star_data_valid_keys(mock_data): - """Test extracting valid data from the dataset.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) - np.testing.assert_array_equal(result, expected) - -def test_extract_star_data_masks(mock_data): - """Test extracting star masks from the dataset.""" - result = extract_star_data(mock_data, train_key="masks", test_key="masks") - - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) - np.testing.assert_array_equal(result, expected) - -def test_extract_star_data_missing_key(mock_data): - """Test that the function raises a KeyError when a key is missing.""" - with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): - extract_star_data(mock_data, train_key="invalid_key", test_key="stars") - -def test_extract_star_data_partially_missing_key(mock_data): - """Test that the function raises a KeyError if only one key is missing.""" - with pytest.raises(KeyError, match="Missing keys in dataset: \\['missing_stars'\\]"): - extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") - - -def test_extract_star_data_tensor_conversion(mock_data): - """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - assert isinstance(result, np.ndarray), "The result should be a NumPy array" - assert result.dtype == np.float32, "The NumPy array should have dtype float32" - -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - 
correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] - ) - - # Mock the internal function calls: - with patch('wf_psf.data.training_preprocessing.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.training_preprocessing.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: - - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose(result[0, :], np.array([0, -0.1, -0.2])) # First star Zernike coefficients - assert np.allclose(result[1, :], np.array([0, -0.3, -0.4])) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] - ) - - # Mock internal function calls - with patch('wf_psf.data.training_preprocessing.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.training_preprocessing.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: - - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array([ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4] - ]) - assert np.allclose(result, expected_result) - -def test_reference_shifts_broadcasting(): - reference_shifts = [-1/3, -1/3] # Example reference_shifts - shifts = np.random.rand(2, 2400) # Example shifts array - - # Ensure reference_shifts is a NumPy array (if it's not already) - reference_shifts = np.array(reference_shifts) - - # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts[:, None], shifts.shape) # Shape will be (2, 2400) - - # Ensure shapes are compatible for subtraction - displacements = reference_shifts - shifts - - # Test the result - assert displacements.shape == shifts.shape, "Shapes do not match" - assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py new file mode 100644 index 00000000..28a2c1af --- /dev/null +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 
+1,485 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. + +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine, +) + +from wf_psf.utils.read_config import RecursiveNamespace + + +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + ), + ), + ) + ) + return training_config + + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, + ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins, 2), + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + 
"dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + """Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) + + # Assert that the model_params were overwritten correctly + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference("/dummy/path.yaml") + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called["load"] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) + + inference.prepare_configs() + + assert 
"load" in called # Confirm lazy load happened + + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) + ) + assert inference.batch_size == 4 + + +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) + + # Assert calls to the mocked methods + mock_load_trained_psf_model.assert_called_once_with( + mock_training_config, mock_data_config, weights_path_pattern + ) + + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + + +@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + 
mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q + + +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_get_psfs_runs_inference( + mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup +): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock_compute_psfs.call_count == 1 + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): + inference._prepare_positions_and_seds() + finally: + 
patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): + inference._prepare_positions_and_seds() + finally: + patcher.stop() diff --git a/src/wf_psf/tests/test_metrics/conftest.py b/src/wf_psf/tests/test_metrics/conftest.py index d015024e..b32dfbf1 100644 --- a/src/wf_psf/tests/test_metrics/conftest.py +++ b/src/wf_psf/tests/test_metrics/conftest.py @@ -7,9 +7,9 @@ """ + import pytest from unittest.mock import patch, MagicMock -import numpy as np import tensorflow as tf @@ -27,6 +27,7 @@ def load_weights(self, *args, **kwargs): # Simulate the weight loading pass + class TFGroundTruthSemiParametricField(TFSemiParametricField): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -35,6 +36,7 @@ def __init__(self, *args, **kwargs): def call(self, inputs, **kwargs): return inputs + @pytest.fixture def mock_psf_model(): # Return a mock instance of TFSemiParametricField @@ -42,8 +44,10 @@ def mock_psf_model(): psf_model.load_weights = MagicMock() # Mock load_weights method return psf_model + @pytest.fixture def mock_get_psf_model(mock_psf_model): - with patch('wf_psf.psf_models.psf_models.get_psf_model', return_value=mock_psf_model) as mock_method: + with patch( + "wf_psf.psf_models.psf_models.get_psf_model", return_value=mock_psf_model + ) as mock_method: yield mock_method - diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 8b498295..28653761 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,8 +1,8 @@ - from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.metrics.metrics_interface import evaluate_model +from wf_psf.data.data_handler import DataHandler + @pytest.fixture def mock_metrics_params(): @@ -10,22 +10,21 @@ def mock_metrics_params(): eval_mono_metric=True, eval_opd_metric=False, eval_test_shape_results_dict=True, - eval_train_shape_results_dict=False + eval_train_shape_results_dict=False, ) + @pytest.fixture def mock_trained_model_params(): - return MagicMock( - model_params=MagicMock(model_name="mock_model"), - id_name="mock_id" - ) + return MagicMock(model_params=MagicMock(model_name="mock_model"), id_name="mock_id") + @pytest.fixture def mock_data(): # Create mock instances of the required attributes mock_data_params = MagicMock() mock_simPSF = MagicMock() - + # Mock the `data_params` dictionary for "train" and "test" data mock_data_params.train = MagicMock() mock_data_params.test = MagicMock() @@ -41,34 +40,54 @@ def mock_data(): mock_data_handler.test_data = MagicMock() mock_data_handler.training_data.dataset = { - 'positions': 'train_positions', - 'noisy_stars': 'train_noisy_stars', + "positions": "train_positions", + "noisy_stars": "train_noisy_stars", 
"SEDs": "train_SEDs", - "C_poly": "train_C_poly" + "C_poly": "train_C_poly", } mock_data_handler.test_data.dataset = { - 'positions': 'test_positions', - 'noisy_stars': 'test_noisy_stars', + "positions": "test_positions", + "noisy_stars": "test_noisy_stars", "SEDs": "test_SEDs", - "C_poly": "test_C_poly" + "C_poly": "test_C_poly", } - mock_data_handler.sed_data = 'mock_sed_data' - + mock_data_handler.sed_data = "mock_sed_data" + # Return the mocked DataHandler instance return mock_data_handler -def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_data, mock_psf_model, mocker): +def test_evaluate_model( + mock_metrics_params, mock_trained_model_params, mock_data, mock_psf_model, mocker +): # Mock the metric functions - with patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_polychromatic_lowres', new_callable=MagicMock) as mock_evaluate_poly_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_mono_rmse', new_callable=MagicMock) as mock_evaluate_mono_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_opd', new_callable=MagicMock) as mock_evaluate_opd_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_shape', new_callable=MagicMock) as mock_evaluate_shape_results_dict, \ - patch('numpy.save', new_callable=MagicMock) as mock_np_save: + with ( + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_polychromatic_lowres", + new_callable=MagicMock, + ) as mock_evaluate_poly_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_chi2", + new_callable=MagicMock, + ) as mock_evaluate_chi2_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_mono_rmse", + new_callable=MagicMock, + ) as mock_evaluate_mono_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_opd", + new_callable=MagicMock, + ) as mock_evaluate_opd_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_shape", + new_callable=MagicMock, + ) as mock_evaluate_shape_results_dict, + patch("numpy.save", new_callable=MagicMock) as mock_np_save, + ): # Mock the logger - logger = mocker.patch('wf_psf.metrics.metrics_interface.logger') + logger = mocker.patch("wf_psf.metrics.metrics_interface.logger") # Call evaluate_model evaluate_model( @@ -76,13 +95,19 @@ def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_dat trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", - metrics_output="/mock/metrics/output" + metrics_output="/mock/metrics/output", ) - # Assertions for metric functions - assert mock_evaluate_poly_metric.call_count == 2 # Called twice, once for each dataset - assert mock_evaluate_mono_metric.call_count == 2 # Called twice, once for each dataset + # Assertions for metric functions + assert ( + mock_evaluate_poly_metric.call_count == 2 + ) # Called twice, once for each dataset + assert ( + mock_evaluate_chi2_metric.call_count == 2 + ) # Called twice, once for each dataset + assert ( + mock_evaluate_mono_metric.call_count == 2 + ) # Called twice, once for each dataset mock_evaluate_opd_metric.assert_not_called() # Should not be called because the flag is False mock_evaluate_shape_results_dict.assert_called_once() # Should be called only for the test dataset @@ -91,5 +116,7 @@ def 
test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_dat # Validate the np.save call arguments output_path, saved_data = mock_np_save.call_args[0] # Extract arguments - assert "/mock/metrics/output/metrics-mock_modelmock_id" in output_path # Ensure correct path format - assert isinstance(saved_data, dict) # Ensure data being saved is a dictionary \ No newline at end of file + assert ( + "/mock/metrics/output/metrics-mock_modelmock_id" in output_path + ) # Ensure correct path format + assert isinstance(saved_data, dict) # Ensure data being saved is a dictionary diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 81042c9a..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,8 @@ import pytest import numpy as np import tensorflow as tf -from wf_psf.psf_models.psf_model_physical_polychromatic import ( +from unittest.mock import patch +from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) from wf_psf.utils.configs_handler import DataConfigHandler @@ -28,14 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return mock_instance @@ -47,258 +61,57 @@ def mock_model_params(mocker): return model_params_mock -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, 
mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data - ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor - ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - 
interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) + return instance - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior +def test_compute_zernikes(mocker, physical_layer_instance): + # Expected output of mock components + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], 
[[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], dtype=tf.float32 ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior + # Patch tf_poly_Z_field method + mocker.patch.object( + TFPhysicalPolychromaticField, + "tf_poly_Z_field", + return_value=padded_zernike_param, ) - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) - assert padded_zk_prior.shape == (1, 4, 1, 1) - - -def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity - - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) - + # Patch tf_physical_layer.call method + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) - mocker.patch.object( - physical_layer_instance, - "pad_zernikes", + + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior), ) - # Call the method under test + # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) assert zernike_coeffs.shape == expected_values.shape - - # Assert that the tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index 066e1328..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -7,8 +7,8 @@ """ -from wf_psf.psf_models import ( - psf_models, +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.models import ( psf_model_semiparametric, psf_model_physical_polychromatic, ) diff --git a/src/wf_psf/tests/test_training/train_utils_test.py 
b/src/wf_psf/tests/test_training/train_utils_test.py index 8cefa5f6..ecb4d822 100644 --- a/src/wf_psf/tests/test_training/train_utils_test.py +++ b/src/wf_psf/tests/test_training/train_utils_test.py @@ -150,7 +150,6 @@ def test_calculate_sample_weights_integration( use_sample_weights, loss, expected_output_type ): """Test different cases for sample weight computation.""" - # Generate dummy image data batch_size, height, width = 5, 32, 32 @@ -179,9 +178,21 @@ def test_calculate_sample_weights_integration( @pytest.mark.parametrize( "loss", [None, "mean_squared_error", "masked_mean_squared_error"] ) -def test_calculate_sample_weights_unit(mock_noise_estimator, loss): +def test_calculate_sample_weights_unit(loss): """Test sample weighting strategy with random images.""" - outputs = np.random.rand(10, 32, 32) # 10 images of size 32x32 + # Generate dummy image data + batch_size, height, width = 10, 32, 32 + + if loss == "masked_mean_squared_error": + # Create image-mask pairs: last dimension has [image, mask] + outputs = np.random.rand(batch_size, height, width, 2) + outputs[..., 1] = np.random.randint( + 0, 2, size=(batch_size, height, width) + ) # Binary mask + else: + outputs = np.random.rand(batch_size, height, width) + + # Calculate sample weights result = train_utils.calculate_sample_weights( outputs, use_sample_weights=True, loss=loss ) @@ -200,10 +211,28 @@ def test_calculate_sample_weights_unit(mock_noise_estimator, loss): @pytest.mark.parametrize( "loss", [None, "mean_squared_error", "masked_mean_squared_error"] ) -def test_calculate_sample_weights_high_variance(mock_noise_estimator, loss): +def test_calculate_sample_weights_high_variance(loss): """Test case for high variance (noisy images).""" # Create high variance images with more noise - outputs = np.random.normal(loc=0.0, scale=10.0, size=(5, 32, 32)) # Larger noise + # Generate dummy image data + batch_size, height, width = 10, 32, 32 + + if loss == "masked_mean_squared_error": + # Create image-mask pairs: last dimension has [image, mask] + outputs = np.zeros((batch_size, height, width, 2), dtype=np.float32) + mask_prob: float = 0.5 # Probability of a pixel being unmasked + # High variance images + outputs[..., 0] = np.random.normal( + loc=0.0, scale=10.0, size=(batch_size, height, width) + ) + # Random masks with adjustable sparsity + outputs[..., 1] = ( + np.random.rand(batch_size, height, width) < mask_prob + ).astype(np.float32) + else: + outputs = np.random.normal( + loc=0.0, scale=10.0, size=(batch_size, height, width) + ) # Larger noise # Calculate sample weights weights = train_utils.calculate_sample_weights( @@ -416,7 +445,6 @@ def test_general_train_cycle_with_callbacks( mock_test_setup, cycle_def, param_callback, non_param_callback, general_callback ): """Test general_train_cycle with different cycle_def and callback configurations.""" - # Unpack test setup mock_model = mock_test_setup["mock_model"] diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 5fb7d7cd..071181c3 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,17 @@ :Author: Jennifer Pollack - """ import pytest +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from 
wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + DataConfigHandler, +) import os @@ -108,9 +110,7 @@ def test_get_run_config(path_to_repo_dir, path_to_tmp_output_dir, path_to_config assert type(config_class) is RegisterConfigClass -def test_data_config_handler_init( - mock_training_conf, mock_data_read_conf, mocker -): +def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocker): # Mock read_conf function mock_data_read_conf() @@ -120,13 +120,24 @@ def test_data_config_handler_init( "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance ) - # Patch the load_dataset and process_sed_data methods inside DataHandler - mocker.patch.object(DataHandler, "load_dataset") + # Patch process_sed_data method mocker.patch.object(DataHandler, "process_sed_data") + # Patch validate_and_process_datasetmethod + mocker.patch.object(DataHandler, "validate_and_process_dataset") + + # Patch load_dataset to assign dataset + def mock_load_dataset(self): + self.dataset = { + "SEDs": ["dummy_sed_data"], + "positions": ["dummy_positions_data"], + } + + mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset) + # Create DataConfigHandler instance data_config_handler = DataConfigHandler( - "/path/to/data_config.yaml", + "/path/to/data_config.yaml", mock_training_conf.training.model_params, mock_training_conf.training.training_hparams.batch_size, ) @@ -142,7 +153,10 @@ def test_data_config_handler_init( data_config_handler.test_data.n_bins_lambda == mock_training_conf.training.model_params.n_bins_lda ) - assert (data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size) # Default value + assert ( + data_config_handler.batch_size + == mock_training_conf.training.training_hparams.batch_size + ) def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): @@ -234,23 +248,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_repo_dir, path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler( - path_to_repo_dir, path_to_tmp_output_dir, path_to_config_dir - ) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) diff --git a/src/wf_psf/tests/test_utils/conftest.py b/src/wf_psf/tests/test_utils/conftest.py index abcfa93f..caee7878 100644 --- a/src/wf_psf/tests/test_utils/conftest.py +++ b/src/wf_psf/tests/test_utils/conftest.py @@ -10,7 +10,6 @@ import pytest import os -from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler cwd = os.getcwd() diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index 7d125014..adabeaa0 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -16,6 +16,7 @@ ) from wf_psf.sims.psf_simulator import PSFSimulator + def test_initialization(): """Test if NoiseEstimator initializes correctly.""" img_dim = (50, 50) @@ -27,6 +28,7 @@ def test_initialization(): 
assert isinstance(estimator.window, np.ndarray) assert estimator.window.shape == img_dim + def test_init_window(): """Test that the exclusion window is correctly initialized.""" img_dim = (50, 50) @@ -41,13 +43,17 @@ def test_init_window(): inside_radius = np.sqrt((x - mid_x) ** 2 + (y - mid_y) ** 2) <= win_rad assert estimator.window[x, y] == (not inside_radius) + def test_sigma_mad(): """Test the MAD-based standard deviation estimation.""" - data = np.array([1, 1, 2, 2, 3, 3, 4, 4, 100]) # Outlier should not heavily influence MAD + data = np.array( + [1, 1, 2, 2, 3, 3, 4, 4, 100] + ) # Outlier should not heavily influence MAD expected_sigma = 1.4826 * np.median(np.abs(data - np.median(data))) assert np.isclose(NoiseEstimator.sigma_mad(data), expected_sigma, atol=1e-4) + def test_estimate_noise_without_default_window(): """Test noise estimation with the default exclusion window (no custom mask).""" img_dim = (50, 50) @@ -59,10 +65,11 @@ def test_estimate_noise_without_default_window(): image = np.random.normal(0, 10, img_dim) noise_estimation = estimator.estimate_noise(image) - + # The estimated noise should be close to 10 (the true std) assert np.isclose(noise_estimation, 10, atol=2) + def test_estimate_noise_with_custom_mask(): """Test noise estimation with a custom mask applied outside the exclusion radius.""" img_dim = (50, 50) @@ -80,6 +87,7 @@ def test_estimate_noise_with_custom_mask(): assert np.isclose(noise_estimation, 5, atol=1) + def test_apply_mask_with_none_mask(): """Test apply_mask when mask is None.""" img_dim = (10, 10) @@ -88,13 +96,16 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), "apply_mask should return the window when mask is None." + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." + def test_apply_mask_with_valid_mask(): """Test apply_mask when a valid mask is provided.""" img_dim = (10, 10) estimator = NoiseEstimator(img_dim, win_rad=3) - + # Create a custom mask custom_mask = np.ones(img_dim, dtype=bool) custom_mask[5, 5] = False # Set a pixel to False to exclude it from the window @@ -103,7 +114,10 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), "apply_mask did not apply the mask correctly." + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." + def test_apply_mask_with_zeroed_mask(): """Test apply_mask when a zeroed mask is provided.""" @@ -117,7 +131,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), "apply_mask did not handle the zeroed mask correctly." + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." 
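+
+
+# NOTE: illustrative sketch only, not part of the original suite. It reuses
+# only API already exercised above (NoiseEstimator, apply_mask, sigma_mad) and
+# the module-level numpy import; the image size, seed, noise level, and
+# tolerance are arbitrary example choices.
+def test_apply_mask_then_sigma_mad_sketch():
+    """Sketch: combine the exclusion window with a custom mask, then recover
+    the noise level of the kept pixels with the MAD-based estimator."""
+    img_dim = (30, 30)
+    estimator = NoiseEstimator(img_dim, win_rad=5)
+
+    # Custom mask that additionally drops a corner region from the window
+    custom_mask = np.ones(img_dim, dtype=bool)
+    custom_mask[:5, :5] = False
+
+    # Combined selection: pixels kept by both the window and the custom mask
+    # (cast to bool defensively so it can be used for boolean indexing below)
+    combined = np.asarray(estimator.apply_mask(custom_mask), dtype=bool)
+    assert combined.shape == img_dim
+
+    # Pure Gaussian noise image with a known standard deviation
+    true_sigma = 4.0
+    np.random.seed(0)
+    image = np.random.normal(0, true_sigma, img_dim)
+
+    # The MAD-based estimate over the kept pixels should be close to true_sigma
+    sigma_est = NoiseEstimator.sigma_mad(image[combined])
+    assert np.isclose(sigma_est, true_sigma, atol=1.0)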
def test_unobscured_zernike_projection(): @@ -180,7 +196,9 @@ def test_tf_decompose_obscured_opd_basis(): tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32) # Create obscurations - obscurations = PSFSimulator.generate_euclid_pupil_obscurations(N_pix=wfe_dim, N_filter=2) + obscurations = PSFSimulator.generate_euclid_pupil_obscurations( + N_pix=wfe_dim, N_filter=2 + ) tf_obscurations = tf.convert_to_tensor(obscurations, dtype=tf.float32) # Create random zernike coefficient array diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index 3ee47fa8..f3c11d6e 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -7,6 +7,7 @@ """ +import gc import numpy as np import time import tensorflow as tf @@ -33,6 +34,7 @@ def get_gpu_info(): device_name = tf.test.gpu_device_name() return device_name + def setup_training(): """Set up Training. @@ -273,7 +275,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. - + Parameters ---------- training_handler: TrainingParamsHandler @@ -295,19 +297,26 @@ def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): Tensor containing the outputs for training output_val: tf.Tensor Tensor containing the outputs for validation - - """ + """ if training_handler.training_hparams.loss == "mask_mse": loss = train_utils.MaskedMeanSquaredError() monitor = "loss" param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()] non_param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()] outputs = tf.stack( - [data_conf.training_data.dataset["noisy_stars"], data_conf.training_data.dataset["masks"]], axis=-1 + [ + data_conf.training_data.dataset["noisy_stars"], + data_conf.training_data.dataset["masks"], + ], + axis=-1, ) output_val = tf.stack( - [data_conf.test_data.dataset["stars"], data_conf.test_data.dataset["masks"]], axis=-1 + [ + data_conf.test_data.dataset["stars"], + data_conf.test_data.dataset["masks"], + ], + axis=-1, ) else: loss = tf.keras.losses.MeanSquaredError() @@ -332,7 +341,7 @@ def train( This function manages multi-cycle training of a parametric + non-parametric PSF model, including initialization, loss/metric configuration, optimizer setup, model checkpointing, - and optional projection or resetting of non-parametric features. Each cycle can include + and optional projection or resetting of non-parametric features. Each cycle can include both parametric and non-parametric training stages, and training history is saved for each. Parameters @@ -363,7 +372,7 @@ def train( None Side Effects - ------------ + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations @@ -391,10 +400,10 @@ def train( current_cycle += 1 # Instantiate fresh loss, monitor, and independent metric objects per training phase (param / non-param) - loss, param_metrics, non_param_metrics, monitor, outputs, output_val = get_loss_metrics_monitor_and_outputs( - training_handler, data_conf + loss, param_metrics, non_param_metrics, monitor, outputs, output_val = ( + get_loss_metrics_monitor_and_outputs(training_handler, data_conf) ) - + # If projected learning is enabled project DD_features. 
if hasattr(psf_model, "project_dd_features") and psf_model.project_dd_features: if current_cycle > 1: @@ -525,3 +534,8 @@ def train( final_time = time.time() logger.info("\nTotal elapsed time: %f" % (final_time - starting_time)) logger.info("\n Training complete..") + + # Clean up memory + del psf_model + gc.collect() + tf.keras.backend.clear_session() diff --git a/src/wf_psf/training/train_utils.py b/src/wf_psf/training/train_utils.py index a3d6f95a..b26c112a 100644 --- a/src/wf_psf/training/train_utils.py +++ b/src/wf_psf/training/train_utils.py @@ -382,6 +382,48 @@ def configure_optimizer_and_loss( return optimizer, loss, metrics +def compute_noise_std_from_stars( + images: np.ndarray, + masks: Optional[np.ndarray] = None, +) -> Optional[np.ndarray]: + """ + Compute the noise standard deviation from star images. + + Parameters + ---------- + images: np.ndarray + A 3D array of shape (batch_size, height, width) representing star images. + The first dimension is the batch size, and the next two dimensions are the image height and width. + masks: np.ndarray, optional + A 3D array of shape (batch_size, height, width) representing masks for the images. + + Returns + ------- + np.ndarray + An array of standard deviations for each image in the batch, or None if no images are provided. + + """ + if images is not None and len(images.shape) >= 3: + img_dim = (images.shape[1], images.shape[2]) + win_rad = np.ceil(images.shape[1] / 3.33) + std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad) + + if masks is not None: + imgs_std = np.array( + [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)] + ) + else: + imgs_std = np.array([std_est.estimate_noise(_im) for _im in images]) + + else: + logger.warning( + "No images provided for noise standard deviation estimation or there was a problem with the input images." + ) + imgs_std = None + + return imgs_std + + def calculate_sample_weights( outputs: np.ndarray, use_sample_weights: bool, @@ -400,7 +442,7 @@ def calculate_sample_weights( ---------- outputs: np.ndarray A 3D array of shape (batch_size, height, width) representing images, where the first dimension is the batch size - and the next two dimensions are the image height and width. + and the next two dimensions are the image height and width. It can contain the masks in an extra dimension, e.g., (batch_size, height, width, 2), use_sample_weights: bool Flag indicating whether to compute sample weights. If True, sample weights will be computed based on the image noise. loss: str, callable, optional @@ -420,10 +462,7 @@ def calculate_sample_weights( An array of sample weights, or None if `use_sample_weights` is False. 
""" if use_sample_weights: - img_dim = (outputs.shape[1], outputs.shape[2]) - win_rad = np.ceil(outputs.shape[1] / 3.33) - std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad) - + # Compute noise standard deviation from images if loss is not None and ( (isinstance(loss, str) and loss == "masked_mean_squared_error") or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error") @@ -431,13 +470,12 @@ def calculate_sample_weights( logger.info("Estimating noise standard deviation for masked images..") images = outputs[..., 0] masks = np.array(1 - outputs[..., 1], dtype=bool) - imgs_std = np.array( - [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)] - ) + imgs_std = compute_noise_std_from_stars(images, masks) + else: logger.info("Estimating noise standard deviation for images..") # Estimate noise standard deviation - imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs]) + imgs_std = compute_noise_std_from_stars(outputs) # Calculate variances variances = imgs_std**2 diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index a7f76187..6971a6e9 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf -from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models +from wf_psf.data.data_handler import DataHandler from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -120,35 +122,38 @@ class DataConfigHandler: training_model_params : Recursive Namespace object Recursive Namespace object containing the training model parameters batch_size : int - Training hyperparameter used for batched pre-processing of data. + Training hyperparameter used for batched pre-processing of data. """ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True): try: self.data_conf = read_conf(data_conf) - except FileNotFoundError as e: - logger.exception(e) - exit() - except TypeError as e: + except (FileNotFoundError, TypeError) as e: logger.exception(e) exit() self.simPSF = psf_models.simPSF(training_model_params) + + # Extract sub-configs early + train_params = self.data_conf.data.training + test_params = self.data_conf.data.test + self.training_data = DataHandler( dataset_type="training", - data_params=self.data_conf.data, + data_params=train_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) self.test_data = DataHandler( dataset_type="test", - data_params=self.data_conf.data, + data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size @@ -183,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) @@ -203,7 +209,6 @@ def run(self): input configuration. 
""" - train.train( self.training_conf.training, self.data_conf, @@ -255,102 +260,146 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) self._file_handler = file_handler - self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.training_conf = training_conf + self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) + self.trained_psf_model = self._load_trained_psf_model() @property def metrics_conf(self): return self._metrics_conf - @property - def metrics_dir(self): - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): + """Returns the loaded training configuration.""" return self._training_conf + @training_conf.setter + def training_conf(self, training_conf): + """ + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf + @property def plotting_conf(self): return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - return self._load_data_conf() - - @property - def psf_model(self): - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, self.data_conf, + weights_path_pattern, ) - @property - def weights_path(self): - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. - - Helper method to get the trained model path. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace + def _get_training_conf_path_from_metrics(self): + """ + Retrieves the full path to the training config based on the metrics configuration. Returns ------- str - A string representing the path to the trained model output run directory. - + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. + FileNotFoundError + If the file does not exist at the constructed path. 
""" - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + trained_model_path = self._get_trained_model_path() - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." - ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e + + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, + ) + + if not os.path.exists(training_conf_path): + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" ) - def _load_training_conf(self, training_conf): - """Load Training Conf. - Load the training configuration if training_conf is not provided. + return training_conf_path - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. + def _get_trained_model_path(self): + """ + Determine the trained model path from either: + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- - RecursiveNamespace storing the training configuration parameter settings. + str + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) + + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return training_conf + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" + ) + return trained_model_path + + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) + return fallback_path def _load_data_conf(self): """Load Data Conf. @@ -374,27 +423,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified psf model weights path. - - Returns - ------- - weights_basename: str - The basename of the psf model weights to be loaded. 
- - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -437,20 +465,17 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. """ - logger.info( - "Running metrics evaluation on psf model: {}".format(self.weights_path) - ) + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, - self.weights_path, + self.trained_psf_model, self.metrics_dir, ) @@ -555,7 +580,6 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params): metrics_run_id_name: list List containing the model name and id for each training run """ - try: training_conf = read_conf( os.path.join( @@ -570,9 +594,7 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params): except (TypeError, FileNotFoundError): logger.info("Trained model path not provided...") logger.info( - "Trying to retrieve training config file from workdir: {}".format( - wf_outdir - ) + f"Trying to retrieve training config file from workdir: {wf_outdir}" ) training_confs = [ @@ -623,9 +645,7 @@ def load_metrics_into_dict(self): "metrics-" + run_id_name + ".npy", ) logger.info( - "Attempting to read in trained model config file...{}".format( - output_path - ) + f"Attempting to read in trained model config file...{output_path}" ) try: metrics_dict[k].append( diff --git a/src/wf_psf/utils/graph_utils.py b/src/wf_psf/utils/graph_utils.py index 1f6c5463..96f9b270 100644 --- a/src/wf_psf/utils/graph_utils.py +++ b/src/wf_psf/utils/graph_utils.py @@ -1,7 +1,7 @@ import numpy as np -class GraphBuilder(object): +class GraphBuilder: r"""GraphBuilder class. This class computes the necessary quantities for RCA's graph constraint. @@ -112,10 +112,10 @@ def _build_graphs(self): R -= vect.T.dot(vect.dot(R)) if self.verbose: print( - " > selected e: {}\tselected a:".format(e) - + "{}\t chosen index: {}/{}".format(a, j, self.n_eigenvects) + f" > selected e: {e}\tselected a:" + + f"{a}\t chosen index: {j}/{self.n_eigenvects}" ) - self.VT = np.vstack((eigenvect for eigenvect in list_eigenvects)) + self.VT = np.vstack(eigenvect for eigenvect in list_eigenvects) self.alpha = np.zeros((self.n_comp, self.VT.shape[0])) for i in range(self.n_comp): self.alpha[i, i * self.n_eigenvects + idx[i]] = 1 diff --git a/src/wf_psf/utils/io.py b/src/wf_psf/utils/io.py index b49c9a47..b7444c7b 100644 --- a/src/wf_psf/utils/io.py +++ b/src/wf_psf/utils/io.py @@ -118,7 +118,6 @@ def get_timestamp(self): timestamp: str A string representation of the date and time. 
""" - timestamp = datetime.now().strftime("%Y%m%d%H%M") return timestamp @@ -190,7 +189,6 @@ def copy_conffile_to_output_dir(self, source_file): source_file: str Name of source file """ - source = os.path.join(self.config_path, source_file) destination = os.path.join( self.get_config_dir(self._run_output_dir), source_file diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/utils/preprocessing.py deleted file mode 100644 index 210c03e5..00000000 --- a/src/wf_psf/utils/preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Preprocessing. - -A module with utils to preprocess data. - -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. 
- """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index f6ca3bf8..7282d4c4 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -33,9 +33,9 @@ class RecursiveNamespace(SimpleNamespace): def __init__(self, **kwargs): super().__init__(**kwargs) for key, val in kwargs.items(): - if isinstance(val,dict): + if isinstance(val, dict): setattr(self, key, RecursiveNamespace(**val)) - elif isinstance(val,list): + elif isinstance(val, list): setattr(self, key, list(map(self.map_entry, val))) @staticmethod @@ -56,7 +56,6 @@ def map_entry(entry): entry: type Original type of entry if type is not a dictionary """ - if isinstance(entry, dict): return RecursiveNamespace(**entry) @@ -100,21 +99,19 @@ def read_conf(conf_file): Recursive Namespace object """ - logger.info("Loading...{}".format(conf_file)) - with open(conf_file, "r") as f: + logger.info(f"Loading...{conf_file}") + with open(conf_file) as f: try: my_conf = yaml.safe_load(f) except (ParserError, ScannerError, TypeError): logger.exception( - "There is a syntax problem with your config file. Please check {}.".format( - conf_file - ) + f"There is a syntax problem with your config file. Please check {conf_file}." ) exit() if my_conf is None: raise TypeError( - "Config file {} is empty...Stopping Program.".format(conf_file) + f"Config file {conf_file} is empty...Stopping Program." ) exit() @@ -122,7 +119,7 @@ def read_conf(conf_file): return RecursiveNamespace(**my_conf) except TypeError as e: logger.exception( - "Check your config file for errors. Error Msg: {}.".format(e) + f"Check your config file for errors. Error Msg: {e}." ) exit() @@ -143,7 +140,7 @@ def read_stream(conf_file): A dictionary containing all config files. """ - stream = open(conf_file, "r") + stream = open(conf_file) docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index e27ab33d..eeb67357 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Optional, Tuple +from typing import Tuple import tensorflow as tf import tensorflow_addons as tfa import PIL @@ -20,6 +20,28 @@ def scale_to_range(input_array, old_range, new_range): return input_array +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. + + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. 
+ """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + + def calc_wfe(zernike_basis, zks): wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) return wfe @@ -92,7 +114,6 @@ def generate_SED_elems(SED, sim_psf_toolkit, n_bins=20): n_bins: int Number of wavelength bins """ - feasible_wv, SED_norm = sim_psf_toolkit.calc_SED_wave_values(SED, n_bins) feasible_N = np.array([sim_psf_toolkit.feasible_N(_wv) for _wv in feasible_wv]) @@ -118,7 +139,6 @@ def generate_SED_elems_in_tensorflow( tf_dtype: tf. Tensor Flow data type """ - feasible_wv, SED_norm = sim_psf_toolkit.calc_SED_wave_values(SED, n_bins) feasible_N = np.array([sim_psf_toolkit.feasible_N(_wv) for _wv in feasible_wv]) @@ -383,7 +403,7 @@ def estimate_noise(self, image: np.ndarray, mask: np.ndarray = None) -> float: return self.sigma_mad(image[self.window]) -class ZernikeInterpolation(object): +class ZernikeInterpolation: """Interpolate zernikes This class helps to interpolate zernikes using only the closest K elements @@ -456,7 +476,7 @@ def interpolate_zks(self, interp_positions): return tf.squeeze(interp_zks, axis=1) -class IndependentZernikeInterpolation(object): +class IndependentZernikeInterpolation: """Interpolate each Zernike polynomial independently The interpolation is done independently for each Zernike polynomial.