diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..453a504e --- /dev/null +++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,45 @@ + + + + +### New features + +- Added PSF inference capabilities for generating broadband (polychromatic) PSFs from trained models given star positions and SEDs +- Introduced `PSFInferenceEngine` class to centralize training, simulation, metrics, and inference workflows +- Added `run_type` attribute to `DataHandler` supporting training, simulation, metrics, and inference modes +- Implemented `ZernikeInputsFactory` class for building `ZernikeInputs` instances based on run type +- Added `psf_model_loader.py` module for centralized model weights loading + + +### Bug fixes + +- Fix logger formatting for relative RMSE metrics in `metrics.py` (values were not being displayed) + + + +### Internal changes + +- Refactored `TFPhysicalPolychromatic` and related modules to separate training vs. inference behavior +- Enhanced `ZernikeInputs` data class with intelligent assembly based on run type and available data +- Implemented hybrid loading pattern with eager loading in constructors and lazy-loading via property decorators +- Centralized PSF data extraction in `data_handler` module +- Improved code organization with new `tf_utils.py` module in `psf_models` sub-package +- Updated configuration handling to support inference workflows via `inference_config.yaml` +- Fixed incorrect argument name in `DataHandler` that prevented proper TensorFlow data type conversion +- Removed deprecated `get_obs_positions` method +- Updated documentation to include inference package diff --git a/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..7fb0859a --- /dev/null +++ b/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,37 @@ + + + + + + + +### Internal changes + +- Remove deprecated/optional import tensorflow-addons statement from tf_layers.py + + diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml new file mode 100644 index 00000000..927723c7 --- /dev/null +++ b/config/inference_conf.yaml @@ -0,0 +1,37 @@ + +inference: + # Inference batch size + batch_size: 16 + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # Paths to the configuration files and trained model directory + configs: + # Path to the directory containing the trained model + trained_model_path: /path/to/trained/model/ + + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model + + # Relative Path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml + + # Path to the data config file (this could contain prior information) + data_config_path: + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Num of wavelength bins to reconstruct polychromatic objects. + n_bins_lda: 8 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. 
+ output_Q: 1 + + # Dimension of the pixel PSF postage stamp + output_dim: 64 + + # Flag to perform centroid error correction + correct_centroids: False + + # Flag to perform CCD misalignment error correction + add_ccd_misalignments: True diff --git a/docs/source/api.rst b/docs/source/api.rst index 3bc4d395..cda87632 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -8,6 +8,7 @@ This section contains the API reference for the main packages in WaveDiff. :recursive: wf_psf.data + wf_psf.inference wf_psf.metrics wf_psf.plotting wf_psf.psf_models diff --git a/docs/source/conf.py b/docs/source/conf.py index 988df8fe..01edfeb6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ else: copyright = f"{start_year}, CosmoStat" author = "CosmoStat" -release = "3.0.0" +release = "3.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/configuration.md b/docs/source/configuration.md index ba87cef2..173e229d 100644 --- a/docs/source/configuration.md +++ b/docs/source/configuration.md @@ -137,13 +137,13 @@ model_params: reference_shifts: [-1/3, -1/3] # Euclid-like default shifts # Obscuration / geometry - obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation. + obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation. # CCD misalignments input file path ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt - + # Boolean to use sample weights based on the noise standard deviation estimation - use_sample_weights: True + use_sample_weights: True # Sample weight generalised sigmoid function sample_weights_sigmoid: @@ -220,7 +220,6 @@ training_hparams: n_epochs_non_params: [100, 120] ``` - (metrics_config)= ## `metrics_config.yaml` — Metrics Configuration @@ -402,7 +401,7 @@ plotting_params: ### 4. Example Directory Structure Below is an example of three WaveDiff runs stored under a single parent directory: -``` +```text wf-outputs/ ├── wf-outputs-202305271829 │ ├── config diff --git a/pyproject.toml b/pyproject.toml index 187ead5e..ec0c8d34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' 
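For orientation, the sketch below shows one way the new `inference_conf.yaml` introduced above might be loaded and queried. It is a minimal sketch using plain PyYAML rather than WaveDiff's own `read_conf` helper, and the file path is only illustrative.

```python
import yaml

# Load the inference configuration added in this patch (path is illustrative).
with open("config/inference_conf.yaml") as f:
    cfg = yaml.safe_load(f)

inference = cfg["inference"]
print(inference["batch_size"])                      # 16
print(inference["cycle"])                           # 2
print(inference["configs"]["trained_model_path"])   # /path/to/trained/model/
print(inference["model_params"]["n_bins_lda"])      # 8
```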
dependencies = [ - "numpy>=1.26.4,<2.0", + "numpy>=1.18,<1.24", "scipy", "tensorflow==2.11.0", "tensorflow-estimator", @@ -24,7 +24,7 @@ dependencies = [ "seaborn", ] -version = "3.0.0" +version = "3.1.0" [project.optional-dependencies] docs = [ diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py index d4394f09..988b02fe 100644 --- a/src/wf_psf/__init__.py +++ b/src/wf_psf/__init__.py @@ -2,6 +2,6 @@ # Dynamically import modules to trigger side effects when wf_psf is imported importlib.import_module("wf_psf.psf_models.psf_models") -importlib.import_module("wf_psf.psf_models.psf_model_semiparametric") -importlib.import_module("wf_psf.psf_models.psf_model_physical_polychromatic") -importlib.import_module("wf_psf.psf_models.tf_psf_field") +importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric") +importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic") +importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field") diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 80% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 8b4522bf..01135428 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,10 +8,80 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from fractions import Fraction from typing import Optional +def compute_centroid_correction( + model_params, centroid_dataset, batch_size: int = 1 +) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. + + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, + with zero padding applied to the first column to ensure a consistent shape. 
+ """ + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None + + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, @@ -58,6 +128,8 @@ def compute_zernike_tip_tilt( - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. """ + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter @@ -178,6 +250,18 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py new file mode 100644 index 00000000..052fe730 --- /dev/null +++ b/src/wf_psf/data/data_handler.py @@ -0,0 +1,446 @@ +"""Data Handler Module. + +Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. + +Includes: +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations + +This module serves as a central interface between raw data and modeling components. + +Authors: Jennifer Pollack , Tobias Liaudat +""" + +import os +import numpy as np +import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor +import tensorflow as tf +from typing import Optional, Union +import logging + +logger = logging.getLogger(__name__) + + +class DataHandler: + """ + DataHandler for WaveDiff PSF modeling. + + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, and inference in the WaveDiff framework. 
+ + Parameters + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). + simPSF : PSFSimulator + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. + n_bins_lambda : int + Number of wavelength bins used to discretize SEDs. + load_data : bool, optional + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. + + Attributes + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. + """ + + def __init__( + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, + ): + """ + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. + + Parameters + ---------- + dataset_type : str + One of {"train", "test", "inference"} indicating dataset usage. + data_params : RecursiveNamespace + Configuration object with paths, preprocessing options, and metadata. + simPSF : PSFSimulator + Used to convert SEDs to TensorFlow format. + n_bins_lambda : int + Number of wavelength bins for the SEDs. + load_data : bool, optional + Whether to automatically load and process the dataset (default: True). + dataset : dict or list, optional + A pre-loaded dataset to use directly (overrides `load_data`). + sed_data : array-like, optional + Pre-loaded SED data to use directly. If not provided but `dataset` is, + SEDs are taken from `dataset["SEDs"]`. + + Raises + ------ + ValueError + If SEDs cannot be found in either `dataset` or as `sed_data`. + + Notes + ----- + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + `load_data=True` is used. + - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. 
+ """ + self.dataset_type = dataset_type + self.data_params = data_params + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: + self.load_dataset() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() + else: + self.dataset = None + self.sed_data = None + + @property + def tf_positions(self): + return ensure_tensor(self.dataset["positions"]) + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified dataset type. + + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), + allow_pickle=True, + )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "training": + if "noisy_stars" not in self.dataset: + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "test": + if "stars" not in self.dataset: + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "inference": + pass + else: + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + + if self.dataset_type == "train": + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif self.dataset_type == "test": + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
+ """ + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + + self.sed_data = [ + utils.generate_SED_elems_in_tensorflow( + _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 + ) + for _sed in sed_data + ] + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) + self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """ + Extract and concatenate star-related data from training and test datasets. + + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). + test_key : str + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + Concatenated NumPy array containing the selected data from both + training and test sets. + + Raises + ------ + KeyError + If either the training or test dataset does not contain the + requested key. + + Notes + ----- + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) + + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = True, +) -> Optional[np.ndarray]: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. + run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. 
If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default True + Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. + + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_direct_data(data, effective_key, allow_missing) + except Exception: + if allow_missing: + return None + raise + + +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: + """ + Extract data directly with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. + KeyError + If key is not found in dataset and allow_missing=False. 
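The inference branch of `get_data_array` reads straight from a single `dataset` mapping, so the `allow_missing` contract can be sketched with an illustrative stand-in object (not a real inference dataset).

```python
import numpy as np
from types import SimpleNamespace
from wf_psf.data.data_handler import get_data_array

# Stand-in for an inference data object exposing a plain dict as .dataset.
data = SimpleNamespace(dataset={"positions": np.array([[1.0, 2.0], [3.0, 4.0]])})

positions = get_data_array(data, run_type="inference", key="positions")
print(positions.shape)  # (2, 2)

# A missing key with allow_missing=True (the default) returns None instead of raising.
masks = get_data_array(data, run_type="inference", key="masks", allow_missing=True)
print(masks)  # None
```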
+ + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..0fad9c8e --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,458 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. + +:Author: Tobias Liaudat + +""" + +from dataclasses import dataclass +from typing import Optional, Union +import numpy as np +import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment +from wf_psf.utils.read_config import RecursiveNamespace +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ZernikeInputs: + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + + +class ZernikeInputsFactory: + @staticmethod + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: + """Builds a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. + run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset, positions = None, None + + if run_type in {"training", "simulation", "metrics"}: + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Explicit prior provided; ignoring dataset-based prior." 
+ ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + stamps = get_data_array(data=data, run_type=run_type, key="sources") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + ) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + if len(contributions) == 1: + return contributions[0] + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") + + combined = np.zeros((n_samples, max_order)) + # Pad each contribution to the max order and sum them + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. + + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. 
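A small numerical sketch of how `combine_zernike_contributions` pads contributions of different Zernike orders to a common length before summing; the arrays below are illustrative.

```python
import numpy as np
from wf_psf.data.data_zernike_utils import combine_zernike_contributions

# Two contributions for the same 2 stars: a 3-mode tip/tilt term and a 4-mode focus term.
tip_tilt = np.array([[0.0, 0.1, -0.2],
                     [0.0, 0.3,  0.4]])
focus = np.array([[0.0, 0.0, 0.0,  0.05],
                  [0.0, 0.0, 0.0, -0.02]])

combined = combine_zernike_contributions([tip_tilt, focus])
print(combined.shape)  # (2, 4): tip/tilt padded to 4 modes, then summed with focus
print(combined[0])     # approximately [0.  0.1  -0.2  0.05]
```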
+ """ + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """ + Assemble the total Zernike contribution map by combining the prior, + centroid correction, and CCD misalignment correction. + + Parameters + ---------- + model_params : RecursiveNamespace + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. + centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. + + Returns + ------- + tf.Tensor + A tensor representing the full Zernike contribution map. + """ + zernike_contribution_list = [] + + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, tf.Tensor): + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + + elif not isinstance(zernike_prior, np.ndarray): + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") + + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) + zernike_contribution_list.append(centroid_correction) + else: + logger.info("Skipping centroid correction (not enabled or no dataset).") + + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append(ccd_misalignment) + else: + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) + + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. 
Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) + + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) + + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + + +def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. + + All inputs should be in [m]. + A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, + e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. + + The output zernike coefficient is in [um] units as expected by wavediff. + + To apply match the centroid with a `dx` that has a corresponding `zk1`, + the new PSF should be generated with `-zk1`. + + The same applies to `dy` and `zk2`. + + Parameters + ---------- + dxy : float + Centroid shift in [m]. It can be on the x-axis or the y-axis. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + reference_pix_sampling = 12e-6 + zernike_norm_factor = 2.0 + + # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) + return ( + zernike_norm_factor + * (tel_diameter / 2) + * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) + * 3.0 + ) + + +def compute_zernike_tip_tilt( + star_images: np.ndarray, + star_masks: Optional[np.ndarray] = None, + pixel_sampling: float = 12e-6, + reference_shifts: list[float] = [-1 / 3, -1 / 3], + sigma_init: float = 2.5, + n_iter: int = 20, +) -> np.ndarray: + """ + Compute Zernike tip-tilt corrections for a batch of PSF images. + + This function estimates the centroid shifts of multiple PSFs and computes + the corresponding Zernike tip-tilt corrections to align them with a reference. + + Parameters + ---------- + star_images : np.ndarray + A batch of PSF images (3D array of shape `(num_images, height, width)`). + star_masks : np.ndarray, optional + A batch of masks (same shape as `star_postage_stamps`). Each mask can have: + - `0` to ignore the pixel. + - `1` to fully consider the pixel. + - Values in `(0,1]` as weights for partial consideration. + Defaults to None. + pixel_sampling : float, optional + The pixel size in meters. Defaults to `12e-6 m` (12 microns). + reference_shifts : list[float], optional + The target centroid shifts in pixels, specified as `[dy, dx]`. + Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). + sigma_init : float, optional + Initial standard deviation for centroid estimation. Default is `2.5`. + n_iter : int, optional + Number of iterations for centroid refinement. Default is `20`. + + Returns + ------- + np.ndarray + An array of shape `(num_images, 2)`, where: + - Column 0 contains `Zk1` (tip) values. + - Column 1 contains `Zk2` (tilt) values. + + Notes + ----- + - This function processes all images at once using vectorized operations. + - The Zernike coefficients are computed in the WaveDiff convention. 
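As a concrete illustration of the shift-to-Zernike conversion, the sketch below converts a half-pixel centroid shift (at the 12 µm reference sampling) into the corresponding WaveDiff tip/tilt coefficient; the numbers are only indicative.

```python
from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff

# A 0.5-pixel shift at 12 um per pixel, expressed in metres as the docstring requires.
dx = 0.5 * 12e-6

zk1 = shift_x_y_to_zk1_2_wavediff(dx, tel_focal_length=24.5, tel_diameter=1.2)
print(zk1)  # small positive coefficient in [um]; generate the new PSF with -zk1 to re-centre it
```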
+ """ + from wf_psf.data.centroids import CentroidEstimator + + # Vectorize the centroid computation + centroid_estimator = CentroidEstimator( + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) + + shifts = centroid_estimator.get_intra_pixel_shifts() + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Reshape to ensure it's a column vector (1, 2) + reference_shifts = reference_shifts[None, :] + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + + # Compute displacements + displacements = reference_shifts - shifts # + + # Ensure the correct axis order for displacements (x-axis, then y-axis) + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + + # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + + # Reshape the result back to the original shape of displacements + zk1_2_array = zk1_2_array.reshape(displacements.shape) + + return zk1_2_array + + +def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in zemax conventions. + + All inputs should be in [m]. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + # Convert to waves with a reference of 800nm + zk4 /= 800e-9 + # Remove the peak to valley value + zk4 /= 2.0 + + return zk4 + + +def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in WaveDifff conventions. + + All inputs should be in [m]. + + The output zernike coefficient is in [um] units as expected by wavediff. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + + # Remove the peak to valley value + zk4 /= 2.0 + + # Change units to [um] as Wavediff uses + zk4 *= 1e6 + + return zk4 diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py deleted file mode 100644 index c1402f06..00000000 --- a/src/wf_psf/data/training_preprocessing.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Training Data Processing. - -A module to load and preprocess training and validation test data. - -:Authors: Jennifer Pollack and Tobias Liaudat - -""" - -import os -import numpy as np -import wf_psf.utils.utils as utils -import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -class DataHandler: - """Data Handler. 
- - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. - - Attributes - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Parameters for the current dataset type. - dataset : dict or None - Dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins. - sed_data : tf.Tensor or None - TensorFlow tensor containing processed SED data for training/testing. - load_data_on_init : bool - Flag controlling whether data is loaded during initialization. - """ - - def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True - ): - """ - Initialize the dataset handler for PSF simulation. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred. Default is True. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() - - def load_dataset(self): - """Load dataset. - - Load the dataset based on the specified dataset type. - - """ - self.dataset = np.load( - os.path.join(self.data_params.data_dir, self.data_params.file), - allow_pickle=True, - )[()] - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - if self.dataset_type == "training": - if "noisy_stars" in self.dataset: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") - elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - - def process_sed_data(self): - """Process SED Data. - - A method to generate and process SED data. 
- - """ - self.sed_data = [ - utils.generate_SED_elems_in_tensorflow( - _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 - ) - for _sed in self.dataset["SEDs"] - ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) - self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. 
- - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. 
- """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py new file mode 100644 index 00000000..70d7a9be --- /dev/null +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,745 @@ +"""Inference. + +A module which provides a PSFInference class to perform inference +with trained PSF models. 
It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. + +:Authors: Jennifer Pollack , Tobias Liaudat + +""" + +import os +from pathlib import Path +import numpy as np +from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf +from wf_psf.utils.utils import ensure_batch +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +import tensorflow as tf + + +class InferenceConfigHandler: + """ + Handle configuration loading and management for PSF inference. + + This class manages the loading of inference, training, and data configuration + files required for PSF inference operations. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + inference_config : RecursiveNamespace or None + Loaded inference configuration. + training_config : RecursiveNamespace or None + Loaded training configuration. + data_config : RecursiveNamespace or None + Loaded data configuration. + trained_model_path : Path + Path to the trained model directory. + model_subdir : str + Subdirectory name for model files. + trained_model_config_path : Path + Path to the training configuration file. + data_config_path : str or None + Path to the data configuration file. + """ + + ids = ("inference_conf",) + + def __init__(self, inference_config_path: str): + self.inference_config_path = inference_config_path + self.inference_config = None + self.training_config = None + self.data_config = None + + def load_configs(self): + """ + Load configuration files based on the inference config. + + Loads the inference configuration first, then uses it to determine and load + the training and data configurations. + + Notes + ----- + Updates the following attributes in-place: + - inference_config + - training_config + - data_config (if data_config_path is specified) + """ + + self.inference_config = read_conf(self.inference_config_path) + self.set_config_paths() + self.training_config = read_conf(self.trained_model_config_path) + + if self.data_config_path is not None: + # Load the data configuration + self.data_config = read_conf(self.data_config_path) + + def set_config_paths(self): + """ + Extract and set the configuration paths from the inference config. + + Sets the following attributes: + - trained_model_path + - model_subdir + - trained_model_config_path + - data_config_path + """ + # Set config paths + config_paths = self.inference_config.inference.configs + + self.trained_model_path = Path(config_paths.trained_model_path) + self.model_subdir = config_paths.model_subdir + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) + self.data_config_path = config_paths.data_config_path + + @staticmethod + def overwrite_model_params(training_config=None, inference_config=None): + """ + Overwrite training model_params with values from inference_config if available. + + Parameters + ---------- + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. 
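Because `overwrite_model_params` only walks the attribute dictionaries of the two namespaces, its in-place effect can be sketched with simple stand-ins, assuming the `wf_psf` package is installed. The `SimpleNamespace` objects below are illustrative, not real WaveDiff configuration objects.

```python
from types import SimpleNamespace
from wf_psf.inference.psf_inference import InferenceConfigHandler

# Stand-ins mimicking the nested structure the static method expects.
training_config = SimpleNamespace(
    training=SimpleNamespace(model_params=SimpleNamespace(n_bins_lda=20, output_dim=32))
)
inference_config = SimpleNamespace(
    inference=SimpleNamespace(model_params=SimpleNamespace(n_bins_lda=8, output_dim=64))
)

InferenceConfigHandler.overwrite_model_params(training_config, inference_config)
print(training_config.training.model_params.n_bins_lda)  # 8  (overwritten in place)
print(training_config.training.model_params.output_dim)  # 64
```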
+ """ + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params + + if model_params and inf_model_params: + for key, value in inf_model_params.__dict__.items(): + if hasattr(model_params, key): + setattr(model_params, key, value) + + +class PSFInference: + """ + Perform PSF inference using a pre-trained WaveDiff model. + + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + x_field : array-like, optional + x coordinates in SHE convention. + y_field : array-like, optional + y coordinates in SHE convention. + seds : array-like, optional + Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. + + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + x_field : array-like or None + x coordinates for PSF positions. + y_field : array-like or None + y coordinates for PSF positions. + seds : array-like or None + Spectral energy distributions. + sources : array-like or None + Source postage stamps. + masks : array-like or None + Source masks. + engine : PSFInferenceEngine or None + The inference engine instance. + + Examples + -------- + Basic usage with position coordinates and SEDs: + + .. code-block:: python + + psf_inf = PSFInference( + inference_config_path="config.yaml", + x_field=[100.5, 200.3], + y_field=[150.2, 250.8], + seds=sed_array + ) + psf_inf.run_inference() + psf = psf_inf.get_psf(0) + """ + + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): + + self.inference_config_path = inference_config_path + + # Inputs for the model + self.x_field = x_field + self.y_field = y_field + self.seds = seds + self.sources = sources + self.masks = masks + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + """ + Get or create the configuration handler. + + Returns + ------- + InferenceConfigHandler + The configuration handler instance with loaded configs. + """ + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """ + Prepare the configuration for inference. + + Overwrites training model parameters with inference configuration values. + """ + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config + ) + + @property + def inference_config(self): + """ + Get the inference configuration. + + Returns + ------- + RecursiveNamespace + The inference configuration object. + """ + return self.config_handler.inference_config + + @property + def training_config(self): + """ + Get the training configuration. 
+ + Returns + ------- + RecursiveNamespace + The training configuration object. + """ + return self.config_handler.training_config + + @property + def data_config(self): + """ + Get the data configuration. + + Returns + ------- + RecursiveNamespace or None + The data configuration object, or None if not available. + """ + return self.config_handler.data_config + + @property + def simPSF(self): + """ + Get or create the PSF simulator. + + Returns + ------- + simPSF + The PSF simulator instance. + """ + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.training_config.training.model_params) + return self._simPSF + + def _prepare_dataset_for_inference(self): + """ + Prepare dataset dictionary for inference. + + Returns + ------- + dict or None + Dictionary containing positions, sources, and masks, or None if positions are invalid. + """ + positions = self.get_positions() + if positions is None: + return None + return {"positions": positions, "sources": self.sources, "masks": self.masks} + + @property + def data_handler(self): + """ + Get or create the data handler. + + Returns + ------- + DataHandler + The data handler instance configured for inference. + """ + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=self._prepare_dataset_for_inference(), + sed_data=self.seds, + ) + self._data_handler.run_type = "inference" + return self._data_handler + + @property + def trained_psf_model(self): + """ + Get or load the trained PSF model. + + Returns + ------- + Model + The loaded trained PSF model. + """ + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model + + def get_positions(self): + """ + Combine x_field and y_field into position pairs. + + Returns + ------- + numpy.ndarray + Array of shape (num_positions, 2) where each row contains [x, y] coordinates. + Returns None if either x_field or y_field is None. + + Raises + ------ + ValueError + If x_field and y_field have different lengths. + """ + if self.x_field is None or self.y_field is None: + return None + + x_arr = np.asarray(self.x_field) + y_arr = np.asarray(self.y_field) + + if x_arr.size == 0 or y_arr.size == 0: + return None + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Flatten arrays to handle any input shape, then stack + x_flat = x_arr.flatten() + y_flat = y_arr.flatten() + + return np.column_stack((x_flat, y_flat)) + + def load_inference_model(self): + """Load the trained PSF model based on the inference configuration. + + Returns + ------- + Model + The loaded trained PSF model. + + Notes + ----- + Constructs the weights path pattern based on the trained model path, + model subdirectory, model name, id name, and cycle number specified in the + configuration files. 
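To make the weights glob pattern concrete, here is a sketch of what `load_inference_model` assembles before handing it to `load_trained_psf_model`; every path component below is a hypothetical placeholder, not a value taken from the repository:

```python
import os

# Hypothetical values standing in for the loaded configs.
model_path = "/path/to/trained/model"   # inference_conf: trained_model_path
model_dir = "model"                     # inference_conf: model_subdir
model_name = "poly"                     # training config: model_params.model_name
id_name = "-coherent_euclid_200stars"   # training config: id_name
cycle = 2                               # inference_conf: cycle

weights_path_pattern = os.path.join(
    model_path,
    model_dir,
    f"{model_dir}*_{model_name}*{id_name}_cycle{cycle}*",
)
print(weights_path_pattern)
# /path/to/trained/model/model/model*_poly*-coherent_euclid_200stars_cycle2*
```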
+ """ + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name + + weights_path_pattern = os.path.join( + model_path, + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", + ) + + # Load the trained PSF model + return load_trained_psf_model( + self.training_config, + self.data_handler, + weights_path_pattern, + ) + + @property + def n_bins_lambda(self): + """Get the number of wavelength bins for inference. + + Returns + ------- + int + The number of wavelength bins used during inference.""" + if self._n_bins_lambda is None: + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) + return self._n_bins_lambda + + @property + def batch_size(self): + """ + Get the batch size for inference. + + Returns + ------- + int + The batch size for processing during inference. + """ + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + """Get the cycle number for inference. + + Returns + ------- + int + The cycle number used for loading the trained model. + """ + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + """Get the output dimension for PSF inference. + + Returns + ------- + int + The output dimension (height and width) of the inferred PSFs. + """ + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """ + Preprocess and return tensors for positions and SEDs with consistent shapes. + + Handles single-star, multi-star, and even scalar inputs, ensuring: + - positions: shape (n_samples, 2) + - sed_data: shape (n_samples, n_bins, 2) + """ + # Ensure x_field and y_field are at least 1D + x_arr = np.atleast_1d(self.x_field) + y_arr = np.atleast_1d(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) + + if sed_data.shape[2] != 2: + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}" + ) + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + + def run_inference(self): + """Run PSF inference and return the full PSF array. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Prepares configurations and input data, initializes the inference engine, + and computes the PSF for all input positions. 
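A minimal shape check of the input conventions enforced by `_prepare_positions_and_seds`: positions become `(n_samples, 2)` and SEDs must be `(n_samples, n_bins, 2)` with (flux, wavelength) in the last axis. The numbers below are illustrative only:

```python
import numpy as np

n_samples, n_bins = 3, 8  # e.g. n_bins matching n_bins_lda in inference_conf.yaml

x_field = np.array([100.5, 200.3, 310.0])
y_field = np.array([150.2, 250.8, 40.7])

# Combine x/y into (n_samples, 2) position pairs, as the method does.
positions = np.column_stack((np.atleast_1d(x_field), np.atleast_1d(y_field)))
assert positions.shape == (n_samples, 2)

# Each SED row carries (flux, wavelength) pairs for every wavelength bin.
seds = np.zeros((n_samples, n_bins, 2), dtype=np.float32)
assert seds.shape[0] == positions.shape[0]  # batch sizes must agree
assert seds.shape[2] == 2                   # last axis is (flux, wavelength)
```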
+ """ + # Prepare the configuration for inference + self.prepare_configs() + + # Prepare positions and SEDs for inference + positions, sed_data = self._prepare_positions_and_seds() + + self.engine = PSFInferenceEngine( + trained_model=self.trained_psf_model, + batch_size=self.batch_size, + output_dim=self.output_dim, + ) + return self.engine.compute_psfs(positions, sed_data) + + def _ensure_psf_inference_completed(self): + """Ensure that PSF inference has been completed. + + Runs inference if it has not been done yet. + """ + if self.engine is None or self.engine.inferred_psfs is None: + self.run_inference() + + def get_psfs(self): + """Get all inferred PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSFs. + """ + self._ensure_psf_inference_completed() + return self.engine.get_psfs() + + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + Parameters + ---------- + index : int, optional + Index of the PSF to retrieve (default is 0). + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSF. + If only a single star was passed during instantiation, the index defaults to 0 + and bounds checking is relaxed. + """ + self._ensure_psf_inference_completed() + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + + def clear_cache(self): + """ + Clear all cached properties and reset the instance. + + This method resets all lazy-loaded properties, including the config handler, + PSF simulator, data handler, trained model, and inference engine. Useful for + freeing memory or forcing a fresh initialization. + + Notes + ----- + After calling this method, accessing any property will trigger re-initialization. + """ + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + self.engine = None + + +class PSFInferenceEngine: + """Engine to perform PSF inference using a trained model. + + This class handles the batch-wise computation of PSFs using a trained PSF model. + It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access. + + Parameters + ---------- + trained_model : Model + The trained PSF model to use for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Attributes + ---------- + trained_model : Model + The trained PSF model used for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Examples + -------- + .. 
code-block:: python + + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) + """ + + def __init__(self, trained_model, batch_size: int, output_dim: int): + self.trained_model = trained_model + self.batch_size = batch_size + self.output_dim = output_dim + self._inferred_psfs = None + + @property + def inferred_psfs(self) -> np.ndarray: + """Access the cached inferred PSFs, if available. + + Returns + ------- + numpy.ndarray or None + The cached inferred PSFs, or None if not yet computed. + """ + return self._inferred_psfs + + def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: + """Compute and cache PSFs for the input source parameters. + + Parameters + ---------- + positions : tf.Tensor + Tensor of shape (n_samples, 2) containing the (x, y) positions + sed_data : tf.Tensor + Tensor of shape (n_samples, n_bins, 2) containing the SEDs + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + PSFs are computed in batches according to the specified batch_size. + Results are cached internally for subsequent access via get_psfs() or get_psf(). + """ + n_samples = positions.shape[0] + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) + + # Initialize counter + counter = 0 + while counter < n_samples: + # Calculate the batch end element + end_sample = min(counter + self.batch_size, n_samples) + + # Define the batch positions + batch_pos = positions[counter:end_sample, :] + batch_seds = sed_data[counter:end_sample, :, :] + batch_inputs = [batch_pos, batch_seds] + + # Generate PSFs for the current batch + batch_psfs = self.trained_model(batch_inputs, training=False) + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() + + # Update the counter + counter = end_sample + + return self._inferred_psfs + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + """ + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. Call compute_psfs() first.") + return self._inferred_psfs + + def get_psf(self, index: int) -> np.ndarray: + """Get the PSF at a specific index. + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Raises + ------ + ValueError + If PSFs have not yet been computed. + """ + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. Call compute_psfs() first.") + return self._inferred_psfs[index] + + def clear_cache(self): + """ + Clear cached inferred PSFs. + + Resets the internal PSF cache to free memory. After calling this method, + compute_psfs() must be called again before accessing PSFs. 
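The engine fills its PSF cache batch by batch; the loop arithmetic in `compute_psfs`, including the final possibly smaller batch, is easy to verify in isolation (toy numbers below):

```python
n_samples, batch_size = 10, 4

counter = 0
batches = []
while counter < n_samples:
    # Clamp the last batch to the number of remaining samples.
    end_sample = min(counter + batch_size, n_samples)
    batches.append((counter, end_sample))
    counter = end_sample

print(batches)  # [(0, 4), (4, 8), (8, 10)] -> last batch holds the 2 leftover samples
```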
+ """ + self._inferred_psfs = None diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 89% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py index 1d2135ba..873509e5 100644 --- a/src/wf_psf/utils/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,47 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff + + +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = positions + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array class CCDMisalignmentCalculator: @@ -121,11 +161,7 @@ def _preprocess_tile_data(self) -> None: self.tiles_z_average = np.mean(self.tiles_z_lims) def _initialize_polygons(self): - """Initialize polygons to look for CCD IDs. - - Each CCD is represented by a polygon defined by its corner points. - - """ + """Initialize polygons to look for CCD IDs""" # Build polygon list corresponding to each CCD self.ccd_polygons = [] @@ -346,6 +382,8 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. 
""" + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff + dz = self.get_dz_from_position(pos) return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 0447d596..942a0622 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -152,8 +152,7 @@ def compute_poly_metric( # Print RMSE values logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, std_rmse) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, std_rel_rmse) - + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {std_rel_rmse:.4e}%") return rmse, rel_rmse, std_rmse, std_rel_rmse @@ -364,9 +363,8 @@ def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_si rel_rmse_std = np.std(rel_rmse_vals) # Print RMSE values - logger.info("Absolute RMSE:\t %.4e % \t +/- %.4e %", rmse, rmse_std) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, rel_rmse_std) - + logger.info("Absolute RMSE:\t %.4e \t +/- %.4e" % (rmse, rmse_std)) + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {rel_rmse_std:.4e}%") return rmse, rel_rmse, rmse_std, rel_rmse_std @@ -596,10 +594,10 @@ def compute_shape_metrics( # Print relative shape/size errors logger.info( - f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e} % \t +/- {std_rel_rmse_e1:.4e} %" + f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e}% \t +/- {std_rel_rmse_e1:.4e}%" ) logger.info( - f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e} % \t +/- {std_rel_rmse_e2:.4e} %" + f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e}% \t +/- {std_rel_rmse_e2:.4e}%" ) # Print number of stars diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3dff2c6c..db410f4e 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -311,7 +311,6 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. 
@@ -329,8 +328,6 @@ def evaluate_model( DataHandler object containing training and test data psf_model: object PSF model object - weights_path: str - Directory location of model weights metrics_output: str Directory location of metrics output @@ -341,8 +338,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) @@ -351,14 +348,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info(f"Loading PSF model weights from {weights_path}") - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py similarity index 98% rename from src/wf_psf/psf_models/psf_model_parametric.py rename to src/wf_psf/psf_models/models/psf_model_parametric.py index 0cd703d7..4a28417d 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -9,7 +9,7 @@ import tensorflow as tf from wf_psf.psf_models.psf_models import register_psfclass -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -215,7 +215,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. 
[1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py similarity index 67% rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py rename to src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index ad61c1b5..fb8bc902 100644 --- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,11 +10,14 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter +from wf_psf.data.data_handler import get_data_array +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes, +) from wf_psf.psf_models import psf_models as psfm -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler -from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -22,6 +25,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -97,8 +101,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): A Recursive Namespace object containing parameters for this PSF model class. training_params: Recursive Namespace A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler - A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients. + data: DataConfigHandler or dict + A DataConfigHandler object or dict that provides access to single or multiple datasets (e.g. train and test), as well as prior knowledge like Zernike coefficients. coeff_mat: Tensor or None, optional Coefficient matrix defining the parametric PSF field model. @@ -108,204 +112,192 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): Initialized instance of the TFPhysicalPolychromaticField class. """ super().__init__(model_params, training_params, coeff_mat) - self._initialize_parameters_and_layers( - model_params, training_params, data, coeff_mat - ) - - def _initialize_parameters_and_layers( - self, - model_params: RecursiveNamespace, - training_params: RecursiveNamespace, - data: DataConfigHandler, - coeff_mat: Optional[tf.Tensor] = None, - ): - """Initialize Parameters of the PSF model. - - This method sets up the PSF model parameters, observational positions, - Zernike coefficients, and components required for the automatically - differentiable optical forward model. + self.model_params = model_params + self.training_params = training_params + self.data = data + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. 
- coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - Notes - ----- - - Initializes Zernike parameters based on dataset priors. - - Configures the PSF model layers according to `model_params`. - - If `coeff_mat` is provided, the model coefficients are updated accordingly. - """ + # Initialize the model parameters self.output_Q = model_params.output_Q - self.obs_pos = get_obs_positions(data) self.l2_param = model_params.param_hparams.l2_param - # Inputs: Save optimiser history Parametric model features - self.save_optim_history_param = ( - model_params.param_hparams.save_optim_history_param - ) - # Inputs: Save optimiser history NonParameteric model features - self.save_optim_history_nonparam = ( - model_params.nonparam_hparams.save_optim_history_nonparam - ) - self._initialize_zernike_parameters(model_params, data) - self._initialize_layers(model_params, training_params) + self.output_dim = model_params.output_dim - # Initialize the model parameters with non-default value + # Initialise lazy loading of external Zernike prior + self._external_prior = None + + # Set Zernike Polynomial Coefficient Matrix if not None if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) - def _initialize_zernike_parameters(self, model_params, data): - """Initialize the Zernike parameters. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - """ - self.zks_prior = get_zernike_prior(model_params, data, data.batch_size) - self.n_zks_total = max( - model_params.param_hparams.n_zernikes, - tf.cast(tf.shape(self.zks_prior)[1], tf.int32), - ) - self.zernike_maps = psfm.generate_zernike_maps_3d( - self.n_zks_total, model_params.pupil_diameter + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 ) - def _initialize_layers(self, model_params, training_params): - """Initialize the layers of the PSF model. - - This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - """ - self._initialize_physical_layer(model_params) - self._initialize_polynomial_Z_field(model_params) - self._initialize_Zernike_OPD(model_params) - self._initialize_batch_polychromatic_layer(model_params, training_params) - self._initialize_nonparametric_opd_layer(model_params, training_params) - - def _initialize_physical_layer(self, model_params): - """Initialize the physical layer of the PSF model. - - This method initializes the physical layer of the PSF model using parameters - specified in the `model_params` object. 
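The constructor now assembles all Zernike contributions once, eagerly, and stores the result as a constant tensor. `assemble_zernike_contributions` itself lives in `data_zernike_utils` and is not shown in this diff; assuming it follows the same pad-to-max-order-and-sum pattern as the former `get_zernike_prior`, the core idea looks like this sketch (array sizes are hypothetical):

```python
import numpy as np

# Hypothetical per-effect contributions for 3 stars: a prior with 10 Zernikes,
# a centroid correction with 3, and a CCD-misalignment term with 4.
contributions = [
    np.random.rand(3, 10),
    np.random.rand(3, 3),
    np.random.rand(3, 4),
]

# Pad every array to the largest Zernike order, then sum the contributions.
max_zk_order = max(c.shape[1] for c in contributions)
total = np.zeros((contributions[0].shape[0], max_zk_order))
for c in contributions:
    total += np.pad(
        c,
        pad_width=[(0, 0), (0, max_zk_order - c.shape[1])],
        mode="constant",
        constant_values=0,
    )

print(total.shape)  # (3, 10)
```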
- - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - """ - self.tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_prior, - interpolation_type=model_params.interpolation_type, - interpolation_args=model_params.interpolation_args, + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1], ) - def _initialize_polynomial_Z_field(self, model_params): - """Initialize the polynomial Zernike field of the PSF model. - - This method initializes the polynomial Zernike field of the PSF model using - parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter ) - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. + # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + # Eagerly initialise model layers + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + _ = self.tf_poly_Z_field + _ = self.tf_np_poly_opd - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. + def _get_run_type(self, data): + if hasattr(data, "run_type"): + run_type = data.run_type + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "metrics", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + + def _assemble_zernike_contributions(self): + zks_inputs = ZernikeInputsFactory.build( + data=self.data, + run_type=self.run_type, + model_params=self.model_params, + prior=self._external_prior if hasattr(self, "_external_prior") else None, + ) + return assemble_zernike_contributions( + model_params=self.model_params, + zernike_prior=zks_inputs.zernike_prior, + centroid_dataset=zks_inputs.centroid_dataset, + positions=zks_inputs.misalignment_positions, + batch_size=self.training_params.batch_size, + ) - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. 
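`_get_run_type` accepts either an object exposing a `run_type` attribute (for example the `DataHandler` built by `PSFInference`) or a plain dict carrying a `"run_type"` key. A small standalone mirror of that logic, runnable on its own:

```python
def get_run_type(data):
    # Mirrors TFPhysicalPolychromaticField._get_run_type.
    if hasattr(data, "run_type"):
        run_type = data.run_type
    elif isinstance(data, dict) and "run_type" in data:
        run_type = data["run_type"]
    else:
        raise ValueError("data must have a 'run_type' attribute or key")

    if run_type not in {"training", "simulation", "metrics", "inference"}:
        raise ValueError(f"Unknown run_type: {run_type}")
    return run_type


print(get_run_type({"run_type": "inference", "positions": None}))  # inference


class _FakeHandler:
    run_type = "metrics"


print(get_run_type(_FakeHandler()))  # metrics
```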
+ @property + def save_param_history(self) -> bool: + """Check if the model should save the optimization history for parametric features.""" + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. + @property + def save_nonparam_history(self) -> bool: + """Check if the model should save the optimization history for non-parametric features.""" + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) + def get_obs_pos(self): + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" - """ - self.batch_size = training_params.batch_size - self.obscurations = psfm.tf_obscurations( - pupil_diam=model_params.pupil_diameter, - N_filter=model_params.LP_filter_length, - rotation_angle=model_params.obscuration_rotation_angle, + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" ) - self.output_dim = model_params.output_dim - self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. + @property + def zks_total_contribution(self): + return self._zks_total_contribution + + @property + def n_zks_total(self): + """Get the total number of Zernike coefficients.""" + return self._n_zks_total + + @property + def zernike_maps(self): + """Get Zernike maps.""" + return self._zernike_maps + + @property + def opd_dim(self): + return self._opd_dim + + @property + def obscurations(self): + return self._obscurations + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field + + @tf_poly_Z_field.deleter + def tf_poly_Z_field(self): + del self._tf_poly_Z_field + + @property + def tf_physical_layer(self): + """Lazy loading of the physical layer of the PSF model.""" + if not hasattr(self, "_tf_physical_layer"): + self._tf_physical_layer = TFPhysicalLayer( + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) + return self._tf_physical_layer + + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + return TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - def _initialize_nonparametric_opd_layer(self, model_params, training_params): - """Initialize the non-parametric OPD layer. - - This method initializes the non-parametric OPD layer using the provided - `model_params` and `training_params`. 
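The refactor keeps the heavy tensors eager in the constructor while turning the Keras layers into cached properties. A generic sketch of the `hasattr`-based lazy property (with deleter) used for `tf_poly_Z_field` and the other layers, demonstrated on a dummy class:

```python
class LazyLayersDemo:
    """Toy example of the lazy-loading pattern used for the PSF model layers."""

    @property
    def expensive_layer(self):
        # Built on first access, then cached on the instance.
        if not hasattr(self, "_expensive_layer"):
            print("building layer...")
            self._expensive_layer = object()  # stands in for a Keras layer
        return self._expensive_layer

    @expensive_layer.deleter
    def expensive_layer(self):
        # `del obj.expensive_layer` drops the cache so the next access rebuilds it.
        del self._expensive_layer


demo = LazyLayersDemo()
first = demo.expensive_layer    # prints "building layer..."
second = demo.expensive_layer   # cached, no rebuild
assert first is second
del demo.expensive_layer        # invalidate the cache
_ = demo.expensive_layer        # prints "building layer..." again
```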
- - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - - """ - # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() - - self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - d_max=model_params.nonparam_hparams.d_max_nonparam, - opd_dim=tf.shape(self.zernike_maps)[1].numpy(), - ) + @property + def tf_np_poly_opd(self): + """Lazy loading of the non-parametric polynomial variations OPD layer.""" + if not hasattr(self, "_tf_np_poly_opd"): + self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) + return self._tf_np_poly_opd def get_coeff_matrix(self): """Get coefficient matrix.""" @@ -335,18 +327,15 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. 
@@ -358,6 +347,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -471,12 +461,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -519,10 +513,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -547,10 +544,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -585,9 +585,10 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -622,8 +623,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction @@ -688,22 +689,21 @@ def call(self, inputs, training=True): # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - # Propagate to obtain the OPD + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD - self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) - ) + # Add L2 regularization loss on parametric OPD maps + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Combine both contributions + opd_maps = 
tf.add(param_opd_maps, nonparam_opd_maps) # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_semiparametric.py rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py index dc535204..7b2ff04d 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -10,9 +10,9 @@ import numpy as np import tensorflow as tf from wf_psf.psf_models import psf_models as psfm -from wf_psf.psf_models import tf_layers as tfl +from wf_psf.psf_models.tf_modules import tf_layers as tfl from wf_psf.utils.utils import decompose_tf_obscured_opd_basis -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, ) @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..e41e3536 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,57 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" + +import logging +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath + +logger = logging.getLogger(__name__) + + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. + + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. 
+ """ + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + status = model.load_weights(weights_path) + status.expect_partial() + + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py index 463d1c52..4c44f698 100644 --- a/src/wf_psf/psf_models/psf_models.py +++ b/src/wf_psf/psf_models/psf_models.py @@ -187,24 +187,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None): def get_psf_model_weights_filepath(weights_filepath): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Parameters ---------- weights_filepath: str - Basename of the psf model weights to be loaded. + Basename of the PSF model weights to be loaded. Returns ------- str - The absolute path concatenated to the basename of the psf model weights to be loaded. + The absolute path concatenated to the basename of the PSF model weights to be loaded. """ try: return glob.glob(weights_filepath)[0].split(".")[0] except IndexError: logger.exception( - "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file." + "PSF weights file not found. Check that you've specified the correct weights file in the your config file." ) raise PSFModelError("PSF model weights error.") diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py similarity index 98% rename from src/wf_psf/psf_models/tf_layers.py rename to src/wf_psf/psf_models/tf_modules/tf_layers.py index eda43305..cdd01e16 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -7,7 +7,8 @@ """ import tensorflow as tf -from wf_psf.psf_models.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf @@ -997,13 +998,10 @@ def call(self, positions): If the shape of the input `positions` tensor is not compatible. 
""" + # Find indices for all positions in one batch operation + idx = find_position_indices(self.obs_pos, positions) - def calc_index(idx_pos): - return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] - - # Calculate the indices of the input batch - indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) - # Recover the prior zernikes from the batch indexes - batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) + # Gather the corresponding Zernike coefficients + batch_zks = tf.gather(self.zks_prior, idx, axis=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py similarity index 100% rename from src/wf_psf/psf_models/tf_modules.py rename to src/wf_psf/psf_models/tf_modules/tf_modules.py diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py similarity index 97% rename from src/wf_psf/psf_models/tf_psf_field.py rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 0c9ba2f7..07b523d1 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -9,15 +9,16 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFZernikeOPD, TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, TFPhysicalLayer, ) -from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor( + get_data_array(data, data.run_type, key="positions"), dtype=tf.float32 + ) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..4bd1246a --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,99 @@ +"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf + + +@tf.function +def find_position_indices(obs_pos, batch_positions): + """Find indices of batch positions within observed positions using vectorized operations. + + This function locates the indices of multiple query positions within a + reference set of observed positions using broadcasting and vectorized operations. 
+ Each position in the batch must have an exact match in the observed positions. + + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos", + ) + + return indices + + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). + + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) diff --git a/src/wf_psf/psf_models/zernikes.py b/src/wf_psf/psf_models/zernikes.py deleted file mode 100644 index dcfa6e39..00000000 --- a/src/wf_psf/psf_models/zernikes.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Zernikes. - -A module to make Zernike maps. - -:Author: Tobias Liaudat and Jennifer Pollack - -""" - -import numpy as np -import zernike as zk -import logging - -logger = logging.getLogger(__name__) - - -def zernike_generator(n_zernikes, wfe_dim): - """ - Generate Zernike maps. - - Based on the zernike github repository. - https://github.com/jacopoantonello/zernike - - Parameters - ---------- - n_zernikes: int - Number of Zernike modes desired. - wfe_dim: int - Dimension of the Zernike map [wfe_dim x wfe_dim]. - - Returns - ------- - zernikes: list of np.ndarray - List containing the Zernike modes. - The values outside the unit circle are filled with NaNs. 
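The new `find_position_indices` helper replaces the per-position `tf.map_fn` lookup in `TFPhysicalLayer` with a single broadcasted comparison; a small TensorFlow check of that matching logic, with made-up positions:

```python
import numpy as np
import tensorflow as tf

obs_pos = tf.constant([[0.0, 0.0], [10.0, 5.0], [20.0, 15.0]])  # (n_obs, 2)
batch_positions = tf.constant([[20.0, 15.0], [0.0, 0.0]])       # (batch_size, 2)

# (1, n_obs, 2) vs (batch_size, 1, 2) -> boolean match table (batch_size, n_obs)
matches = tf.reduce_all(
    tf.equal(tf.expand_dims(obs_pos, 0), tf.expand_dims(batch_positions, 1)),
    axis=2,
)
# argmax picks the first matching observed position for each query.
indices = tf.argmax(tf.cast(matches, tf.int32), axis=1)

print(indices.numpy())  # [2 0]
np.testing.assert_array_equal(indices.numpy(), [2, 0])
```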
- """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", "\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - 
"print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = 
mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, 
unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 62% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 8557704f..185da8d7 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,8 +8,9 @@ import numpy as np import pytest +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator +from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -28,25 +29,6 @@ def calculate_centroid(image, mask=None): return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - return image - - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images - - @pytest.fixture def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" @@ -68,12 +50,6 @@ def simple_star_and_mask(): return image, mask -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully 
considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" @@ -129,133 +105,84 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02]] - ) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - - # Expected shifts based on centroid calculation - expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Check dy values similarly - np.testing.assert_allclose( - args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 - np.testing.assert_allclose( - zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 - ) # Zk2 - - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] - ) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid 
conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts, - ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, ( - f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + # Mock the internal function calls: + with ( + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, centroid_dataset) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) + # Mock internal function calls + with ( + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - # Process the displacements and expected values for each image in the batch - expected_dx = ( - reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] - ) # Expected x-axis shift in meters + # Call function under test + result = compute_centroid_correction(model_params, centroid_dataset) - expected_dy = ( - reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] - ) # Expected y-axis shift in meters + # Validate result shape + assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 - ) - np.testing.assert_allclose( - args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 for each image - np.testing.assert_allclose( - zernike_corrections[:, 1], expected_dy * pixel_sampling * 
0.5 - ) # Zk2 for each image + # Validate expected values (adjust based on behavior) + expected_result = np.array( + [ + [0, -0.1, -0.2], # From training data + [0, -0.3, -0.4], + [0, -0.1, -0.2], # From test data (reused mocked return) + [0, -0.3, -0.4], + ] + ) + assert np.allclose(result, expected_result) # Test for centroid calculation without mask @@ -442,9 +369,9 @@ def test_intra_pixel_shifts(simple_image_with_centroid): expected_y_shift = 2.7 - 2.0 # yc - yc0 # Check that the shifts are correct - assert np.isclose(shifts[0], expected_x_shift), ( - f"Expected {expected_x_shift}, got {shifts[0]}" - ) - assert np.isclose(shifts[1], expected_y_shift), ( - f"Expected {expected_y_shift}, got {shifts[1]}" - ) + assert np.isclose( + shifts[0], expected_x_shift + ), f"Expected {expected_x_shift}, got {shifts[0]}" + assert np.isclose( + shifts[1], expected_y_shift + ), f"Expected {expected_y_shift}, got {shifts[1]}" diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 6159d53a..131922e5 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -9,8 +9,13 @@ """ import pytest +import numpy as np +import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -93,6 +98,76 @@ ) +@pytest.fixture +def mock_data(scope="module"): + """Fixture to provide mock data for testing.""" + # Mock positions and Zernike priors + training_positions = tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) + + # Define dummy 5x5 image patches for stars (mock star images) + # Define varied values for 5x5 star images + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) + + return MockData( + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, + ) + + +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + +@pytest.fixture +def simple_image(scope="module"): + """Fixture for a simple star image.""" + num_images = 1 # Change this to test with multiple images + image = np.zeros((num_images, 5, 5)) # Create a 3D array + image[:, 2, 2] = 1 # Place the star at the center for each image + return image + + +@pytest.fixture +def identity_mask(scope="module"): + """Creates a mask where all pixels are fully considered.""" + return np.ones((5, 5)) + + +@pytest.fixture +def multiple_images(scope="module"): + """Fixture for a batch of images with stars at different positions.""" + images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 + images[0, 2, 2] = 1 # Star at center of image 0 + images[1, 1, 3] = 1 # Star at (1, 3) in image 1 + images[2, 3, 1] = 
1 # Star at (3, 1) in image 2 + return images + + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py new file mode 100644 index 00000000..d29771a1 --- /dev/null +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -0,0 +1,361 @@ +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_handler import ( + DataHandler, + get_data_array, + extract_star_data, +) +from wf_psf.utils.read_config import RecursiveNamespace + + +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically + data_handler = DataHandler( + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True + ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda + + +def test_load_train_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "train_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "noisy_stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal( + data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] + ) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_load_test_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "test_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + dataset_type="test", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "train_data.npy" + + 
mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." + ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_load_test_dataset_missing_stars(tmp_path, simPSF): + """Test that a warning is raised if 'stars' is missing in test data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "test_data.npy" + + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "test", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." + ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_extract_star_data_valid_keys(mock_data): + """Test extracting valid data from the dataset.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_masks(mock_data): + """Test extracting star masks from the dataset.""" + result = extract_star_data(mock_data, train_key="masks", test_key="masks") + + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_missing_key(mock_data): + """Test that the function raises a KeyError when a key is missing.""" + with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): + extract_star_data(mock_data, train_key="invalid_key", test_key="stars") + + +def test_extract_star_data_partially_missing_key(mock_data): + """Test that the function raises a KeyError if only one key is missing.""" + with pytest.raises( + KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" + ): + extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") + + +def test_extract_star_data_tensor_conversion(mock_data): + """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + + assert isinstance(result, np.ndarray), "The result should be a NumPy array" + assert result.dtype == np.float32, "The NumPy array should have dtype float32" + + +def test_reference_shifts_broadcasting(): + reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts + shifts = np.random.rand(2, 2400) # Example shifts array + + # Ensure reference_shifts 
is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to( + reference_shifts[:, None], shifts.shape + ) # Shape will be (2, 2400) + + # Ensure shapes are compatible for subtraction + displacements = reference_shifts - shifts + + # Test the result + assert displacements.shape == shifts.shape, "Shapes do not match" + assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # ================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..66d23309 --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,543 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +import tensorflow as tf +from wf_psf.data.data_zernike_utils 
import ( + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, + compute_zernike_tip_tilt, + pad_tf_zernikes, +) +from types import SimpleNamespace as RecursiveNamespace + + +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + + +def test_training_without_prior(mock_model_params, mock_data): + mock_model_params.use_prior = False + + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + ] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + + assert zinputs.zernike_prior is None + + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): + mock_model_params.use_prior = True + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) + + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) + + assert "Explicit prior provided; ignoring dataset-based prior." 
in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + mock_model_params.use_prior = True + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) + + +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) + + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), + ) + + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + + result = get_np_zernike_prior(data) + + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + 
np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions, + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None, + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, + param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), 
dtype=tf.float32)
+
+    mock_model_params.correct_centroids = False
+    mock_model_params.add_ccd_misalignments = False
+
+    result = assemble_zernike_contributions(
+        model_params=mock_model_params, zernike_prior=tensor_prior
+    )
+    assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test"
+    assert isinstance(result, tf.Tensor)
+    np.testing.assert_array_equal(result.numpy(), np.ones((4, 6)))
+
+
+@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction")
+def test_inconsistent_shapes_raises_error(
+    mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset
+):
+    mock_model_params.add_ccd_misalignments = False
+    mock_centroid.return_value = np.ones((5, 6))  # 5 samples instead of 4
+
+    with pytest.raises(
+        ValueError, match="All contributions must have the same number of samples"
+    ):
+        assemble_zernike_contributions(
+            model_params=mock_model_params,
+            zernike_prior=dummy_prior,
+            centroid_dataset=dummy_centroid_dataset,
+            positions=None,
+        )
+
+
+def test_pad_zernikes_num_of_zernikes_equal():
+    # Prepare the test tensors
+    zk_param = tf.constant([[[[1.0]]], [[[2.0]]]])  # Shape (2, 1, 1, 1)
+    zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]])  # Same shape
+
+    # Reshape to (1, 2, 1, 1)
+    zk_param = tf.reshape(zk_param, (1, 2, 1, 1))
+    zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1))
+
+    # Set n_zks_total to the max number of Zernikes (2 here)
+    n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy())
+
+    # Call pad_tf_zernikes
+    padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total)
+
+    # Assert shapes are equal and correct
+    assert padded_zk_param.shape[1] == n_zks_total
+    assert padded_zk_prior.shape[1] == n_zks_total
+
+    # If num zernikes already equal, output should be unchanged
+    np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy())
+    np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy())
+
+
+def test_pad_zernikes_prior_greater_than_param():
+    zk_param = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+    zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]])  # Shape: (1, 5, 1, 1)
+
+    # Reshape the tensors to have the desired shapes
+    zk_param = tf.reshape(zk_param, (1, 2, 1, 1))  # Reshape zk_param to (1, 2, 1, 1)
+    zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1))  # Reshape zk_prior to (1, 5, 1, 1)
+
+    # Set n_zks_total to the max number of Zernikes
+    n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy())
+
+    # Call the method under test
+    padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total)
+
+    # Assert that the padded tensors have the correct shapes
+    assert padded_zk_param.shape == (1, 5, 1, 1)
+    assert padded_zk_prior.shape == (1, 5, 1, 1)
+
+
+def test_pad_zernikes_param_greater_than_prior():
+    zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]])  # Shape: (1, 4, 1, 1)
+    zk_prior = tf.constant([[[[1]]], [[[2]]]])  # Shape: (2, 1, 1, 1)
+
+    # Reshape the tensors to have the desired shapes
+    zk_param = tf.reshape(zk_param, (1, 4, 1, 1))  # Reshape zk_param to (1, 4, 1, 1)
+    zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1))  # Reshape zk_prior to (1, 2, 1, 1)
+
+    # Set n_zks_total to the max number of Zernikes
+    n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy())
+
+    # Call the method under test
+    padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total)
+
+    # Assert that the padded tensors have the correct shapes
+    assert padded_zk_param.shape == (1, 4, 1, 1)
+    assert padded_zk_prior.shape == (1, 4, 1, 1)
+
+
+def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask):
+    """Test compute_zernike_tip_tilt handling with single batch input and mocks."""
+    # Mock the CentroidEstimator class
+    mock_centroid_calc = mocker.patch(
+        "wf_psf.data.centroids.CentroidEstimator", autospec=True
+    )
+
+    # Create a mock instance and configure get_intra_pixel_shifts()
+    mock_instance = mock_centroid_calc.return_value
+    mock_instance.get_intra_pixel_shifts.return_value = np.array(
+        [[0.05, -0.02]]
+    )  # Shape (1, 2)
+
+    # Mock shift_x_y_to_zk1_2_wavediff to return predictable values
+    mock_shift_fn = mocker.patch(
+        "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff",
+        side_effect=lambda shift: shift * 0.5,  # Mocked conversion for test
+    )
+
+    # Define test inputs (batch of 1 image)
+    pixel_sampling = 12e-6
+    reference_shifts = [-1 / 3, -1 / 3]  # Default Euclid conditions
+
+    # Run the function
+    zernike_corrections = compute_zernike_tip_tilt(
+        simple_image, identity_mask, pixel_sampling, reference_shifts
+    )
+
+    # Expected shifts based on centroid calculation
+    expected_dx = reference_shifts[1] - (-0.02)  # Expected x-axis shift in meters
+    expected_dy = reference_shifts[0] - 0.05  # Expected y-axis shift in meters
+
+    # Expected calls to the mocked function
+    # Extract the arguments passed to mock_shift_fn
+    args, _ = mock_shift_fn.call_args_list[0]  # Get the first call args
+
+    # Compare expected values with the actual arguments passed to the mock function
+    np.testing.assert_allclose(
+        args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0
+    )
+
+    # Check dy values similarly
+    np.testing.assert_allclose(
+        args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0
+    )
+
+    # Expected values based on mock side_effect (0.5 * shift)
+    np.testing.assert_allclose(
+        zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5
+    )  # Zk1
+    np.testing.assert_allclose(
+        zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5
+    )  # Zk2
+
+
+def test_compute_zernike_tip_tilt_batch(mocker, multiple_images):
+    """Test compute_zernike_tip_tilt batch handling of multiple inputs."""
+    # Mock the CentroidEstimator class
+    mock_centroid_calc = mocker.patch(
+        "wf_psf.data.centroids.CentroidEstimator", autospec=True
+    )
+
+    # Create a mock instance and configure get_intra_pixel_shifts()
+    mock_instance = mock_centroid_calc.return_value
+    mock_instance.get_intra_pixel_shifts.return_value = np.array(
+        [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]
+    )  # Shape (3, 2)
+
+    # Mock shift_x_y_to_zk1_2_wavediff to return predictable values
+    mock_shift_fn = mocker.patch(
+        "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff",
+        side_effect=lambda shift: shift * 0.5,  # Mocked conversion for test
+    )
+
+    # Define test inputs (batch of 3 images)
+    pixel_sampling = 12e-6
+    reference_shifts = [-1 / 3, -1 / 3]  # Default Euclid conditions
+
+    # Run the function
+    zernike_corrections = compute_zernike_tip_tilt(
+        star_images=multiple_images,
+        pixel_sampling=pixel_sampling,
+        reference_shifts=reference_shifts,
+    )
+
+    # Check if the mock function was called once with the full batch
+    assert (
+        len(mock_shift_fn.call_args_list) == 1
+    ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}"
+
+    # Get the arguments passed to the mock function for the batch of images
+    args, _ = mock_shift_fn.call_args_list[0]
+
+    
print("Shape of args[0]:", args[0].shape) + print("Contents of args[0]:", args[0]) + print("Mock function call args list:", mock_shift_fn.call_args_list) + + # Reshape args[0] to (N, 2) for batch processing + args_array = np.array(args[0]).reshape(-1, 2) + + # Process the displacements and expected values for each image in the batch + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] + ) # Expected y-axis shift in meters + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose( + args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + np.testing.assert_allclose( + args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose( + zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 for each image + np.testing.assert_allclose( + zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 for each image diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..de111427 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,36 @@ +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors=None, + test_zernike_priors=None, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks, + ) + self.test_data = MockDataset( + positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks, + ) diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/training_preprocessing_test.py deleted file mode 100644 index 3efc8272..00000000 --- a/src/wf_psf/tests/test_data/training_preprocessing_test.py +++ /dev/null @@ -1,407 +0,0 @@ -import pytest -import numpy as np -import tensorflow as tf -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.data.training_preprocessing import ( - DataHandler, - get_obs_positions, - get_zernike_prior, - extract_star_data, - compute_centroid_correction, -) -from unittest.mock import patch - - -class MockData: - def __init__( - self, - training_positions, - test_positions, - training_zernike_priors, - test_zernike_priors, - noisy_stars=None, - noisy_masks=None, - stars=None, - masks=None, - ): - self.training_data = MockDataset( - positions=training_positions, - zernike_priors=training_zernike_priors, - star_type="noisy_stars", - stars=noisy_stars, - masks=noisy_masks, - ) - self.test_data = MockDataset( - positions=test_positions, - zernike_priors=test_zernike_priors, - star_type="stars", - stars=stars, - masks=masks, - ) - - -class MockDataset: - def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = { - "positions": positions, - "zernike_prior": zernike_priors, - star_type: stars, - "masks": masks, - } 
- - -@pytest.fixture -def mock_data(): - # Mock data for testing - # Mock training and test positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) - # Mock noisy stars, stars and masks - noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) - noisy_masks = tf.constant([[1], [0]], dtype=tf.float32) - stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32) - masks = tf.constant([[0], [1]], dtype=tf.float32) - - return MockData( - training_positions, - test_positions, - training_zernike_priors, - test_zernike_priors, - noisy_stars, - noisy_masks, - stars, - masks, - ) - - -def test_load_train_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "train_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda, load_data=False - ) - - # Call the load_dataset method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal( - data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] - ) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_test_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "test_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False - ) - - # Call the load_dataset method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "train_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - 
training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda, load_data=False - ) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") - - -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'stars' is missing in test data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "test_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False - ) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - assert isinstance(data_handler.sed_data, tf.Tensor) - assert data_handler.sed_data.dtype == tf.float32 - assert data_handler.sed_data.shape == ( - len(data_handler.dataset["positions"]), - n_bins_lambda, - len(["feasible_N", "feasible_wv", "SED_norm"]), - ) - - -def test_get_obs_positions(mock_data): - observed_positions = get_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - - -def test_extract_star_data_valid_keys(mock_data): - """Test extracting valid data from the dataset.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) - 
np.testing.assert_array_equal(result, expected) - - -def test_extract_star_data_masks(mock_data): - """Test extracting star masks from the dataset.""" - result = extract_star_data(mock_data, train_key="masks", test_key="masks") - - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) - np.testing.assert_array_equal(result, expected) - - -def test_extract_star_data_missing_key(mock_data): - """Test that the function raises a KeyError when a key is missing.""" - with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): - extract_star_data(mock_data, train_key="invalid_key", test_key="stars") - - -def test_extract_star_data_partially_missing_key(mock_data): - """Test that the function raises a KeyError if only one key is missing.""" - with pytest.raises( - KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" - ): - extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") - - -def test_extract_star_data_tensor_conversion(mock_data): - """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - assert isinstance(result, np.ndarray), "The result should be a NumPy array" - assert result.dtype == np.float32, "The NumPy array should have dtype float32" - - -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock the internal function calls: - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose( - result[0, :], np.array([0, -0.1, -0.2]) - ) # First star Zernike coefficients - assert np.allclose( - result[1, :], np.array([0, -0.3, -0.4]) - ) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock internal function calls - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock extract_star_data to return synthetic star postage stamps - 
mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array( - [ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4], - ] - ) - assert np.allclose(result, expected_result) - - -def test_reference_shifts_broadcasting(): - reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts - shifts = np.random.rand(2, 2400) # Example shifts array - - # Ensure reference_shifts is a NumPy array (if it's not already) - reference_shifts = np.array(reference_shifts) - - # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to( - reference_shifts[:, None], shifts.shape - ) # Shape will be (2, 2400) - - # Ensure shapes are compatible for subtraction - displacements = reference_shifts - shifts - - # Test the result - assert displacements.shape == shifts.shape, "Shapes do not match" - assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py new file mode 100644 index 00000000..4cff7a13 --- /dev/null +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 +1,570 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. 
+ +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine, +) + +from wf_psf.utils.read_config import RecursiveNamespace + + +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + ), + ), + ) + ) + return training_config + + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, + ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins, 2), + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + "dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching + ) + 
inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +@pytest.fixture +def mock_compute_psfs_with_cache(psf_test_setup): + """ + Fixture that patches PSFInferenceEngine.compute_psfs with a side effect + that populates the engine's cache. + + Returns + ------- + dict + Dictionary containing: + - 'mock': The mock object for compute_psfs + - 'inference': The PSFInference instance + - 'positions': Mock positions tensor + - 'seds': Mock SEDs tensor + - 'expected_psfs': Expected PSF array + """ + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + with patch.object(PSFInferenceEngine, "compute_psfs") as mock_compute_psfs: + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + yield { + "mock": mock_compute_psfs, + "inference": inference, + "positions": mock_positions, + "seds": mock_seds, + "expected_psfs": expected_psfs, + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + """Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) + + # Assert that the model_params were overwritten correctly + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference("/dummy/path.yaml") + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the 
overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called["load"] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) + + inference.prepare_configs() + + assert "load" in called # Confirm lazy load happened + + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) + ) + assert inference.batch_size == 4 + + +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) + + # Assert calls to the mocked methods + mock_load_trained_psf_model.assert_called_once_with( + mock_training_config, mock_data_config, weights_path_pattern + ) + + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + + 
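The tests above drive `PSFInferenceEngine.compute_psfs(positions, seds)` purely through mocks: sources are batched, the trained model is evaluated, and the result is cached in `_inferred_psfs` (the engine is constructed with `trained_model`, `batch_size`, and `output_dim`, and reset via `clear_cache()`, as the later engine tests show). The sketch below only illustrates that assumed flow; the class name, the use of `tf.data` for batching, and the two-input `predict` call are illustrative assumptions, not the actual WaveDiff implementation.

```python
import numpy as np
import tensorflow as tf


class SketchPSFInferenceEngine:
    """Illustrative stand-in for the batching/caching behaviour the tests above rely on."""

    def __init__(self, trained_model, batch_size, output_dim):
        self.trained_model = trained_model
        self.batch_size = batch_size
        self.output_dim = output_dim
        self._inferred_psfs = None  # cache inspected by clear_cache()/get_psfs()

    def compute_psfs(self, positions, seds):
        # Slice the (position, SED) pairs into batches and run the trained model on each one.
        dataset = tf.data.Dataset.from_tensor_slices((positions, seds)).batch(self.batch_size)
        batches = [
            np.asarray(self.trained_model.predict([pos_batch, sed_batch], verbose=0))
            for pos_batch, sed_batch in dataset
        ]
        # Cache the stacked result so repeated calls can reuse it without recomputation.
        self._inferred_psfs = np.concatenate(batches, axis=0)
        return self._inferred_psfs

    def clear_cache(self):
        self._inferred_psfs = None
```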
+@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q + + +@patch.object(PSFInference, "_prepare_positions_and_seds") +def test_get_psfs_runs_inference( + mock_prepare_positions_and_seds, mock_compute_psfs_with_cache +): + """Test that get_psfs uses cached PSFs after first computation.""" + mock = mock_compute_psfs_with_cache["mock"] + inference = mock_compute_psfs_with_cache["inference"] + mock_positions = mock_compute_psfs_with_cache["positions"] + mock_seds = mock_compute_psfs_with_cache["seds"] + expected_psfs = mock_compute_psfs_with_cache["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock.call_count == 1 + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right 
shape before being tensorized + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_inference_clear_cache(psf_test_setup): + """Test that PSFInference clear_cache resets the instance of PSFInference.""" + inference = psf_test_setup["inference"] + inference._simPSF = MagicMock() + inference._data_handler = MagicMock() + inference._trained_psf_model = MagicMock() + inference._n_bins_lambda = MagicMock() + inference._batch_size = MagicMock() + inference._cycle = MagicMock() + inference._output_dim = MagicMock() + inference.engine = MagicMock() + + # Clear the cache + inference.clear_cache() + + # Check that the internal cache is None + assert ( + inference._config_handler == None, + inference._simPSF == None, + inference._data_handler == None, + inference._trained_psf_model == None, + inference._n_bins_lambda == None, + inference._batch_size == None, + inference._cycle == None, + inference._output_dim == None, + inference.engine == None, + ), "Inference attributes should be cleared to None" # type: ignore + + +def test_engine_clear_cache(psf_test_setup): + """Test that clear_cache resets the internal PSF cache.""" + inference = psf_test_setup["inference"] + expected_psfs = psf_test_setup["expected_psfs"] + + # Create the engine and compute PSFs + inference.engine = PSFInferenceEngine( + trained_model=inference.trained_psf_model, + batch_size=inference.batch_size, + output_dim=inference.output_dim, + ) + + inference.engine._inferred_psfs = expected_psfs + + # Clear the cache + inference.engine.clear_cache() + + # Check that the internal cache is None + assert ( + inference.engine._inferred_psfs is None + ), "PSF cache should be cleared to None" diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 44bd09bf..35f3b9a1 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import pytest from 
wf_psf.metrics.metrics_interface import evaluate_model -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler @pytest.fixture @@ -106,7 +106,6 @@ def test_evaluate_model_flags( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -134,7 +133,6 @@ def test_missing_ground_truth_model_raises( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) @@ -168,7 +166,6 @@ def test_plotting_config_passed( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -200,8 +197,6 @@ def test_evaluate_model( ) as mock_evaluate_shape_results_dict, patch("numpy.save", new_callable=MagicMock) as mock_np_save, ): - # Mock the logger - _ = mocker.patch("wf_psf.metrics.metrics_interface.logger") # Call evaluate_model evaluate_model( @@ -209,7 +204,6 @@ def test_evaluate_model( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index cae7b141..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,8 @@ import pytest import numpy as np import tensorflow as tf -from wf_psf.psf_models.psf_model_physical_polychromatic import ( +from unittest.mock import patch +from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) from wf_psf.utils.configs_handler import DataConfigHandler @@ -28,14 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return 
mock_instance @@ -47,256 +61,57 @@ def mock_model_params(mocker): return model_params_mock -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data - ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor - ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal 
methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) + return instance - # Create a mock for the TFPhysicalLayer class - mocker.patch("wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer") - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call 
the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior +def test_compute_zernikes(mocker, physical_layer_instance): + # Expected output of mock components + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], dtype=tf.float32 ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior + # Patch tf_poly_Z_field method + mocker.patch.object( + TFPhysicalPolychromaticField, + "tf_poly_Z_field", + return_value=padded_zernike_param, ) - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) - assert padded_zk_prior.shape == (1, 4, 1, 1) - - -def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity - - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) - + # Patch tf_physical_layer.call method + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) - mocker.patch.object( - physical_layer_instance, - "pad_zernikes", + + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior), ) - # Call the method under test + # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) assert zernike_coeffs.shape == expected_values.shape - - # Assert that the 
tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index 066e1328..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -7,8 +7,8 @@ """ -from wf_psf.psf_models import ( - psf_models, +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.models import ( psf_model_semiparametric, psf_model_physical_polychromatic, ) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index d95761e9..57dfdc8a 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,17 @@ :Author: Jennifer Pollack - """ import pytest +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + DataConfigHandler, +) import os @@ -116,10 +118,21 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance ) - # Patch the load_dataset and process_sed_data methods inside DataHandler - mocker.patch.object(DataHandler, "load_dataset") + # Patch process_sed_data method mocker.patch.object(DataHandler, "process_sed_data") + # Patch validate_and_process_datasetmethod + mocker.patch.object(DataHandler, "validate_and_process_dataset") + + # Patch load_dataset to assign dataset + def mock_load_dataset(self): + self.dataset = { + "SEDs": ["dummy_sed_data"], + "positions": ["dummy_positions_data"], + } + + mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset) + # Create DataConfigHandler instance data_config_handler = DataConfigHandler( "/path/to/data_config.yaml", @@ -141,7 +154,7 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke assert ( data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size - ) # Default value + ) def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): @@ -229,21 +242,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler(path_to_tmp_output_dir, path_to_config_dir) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index dacf0bdc..cc7f2a2b 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -20,11 +20,6 @@ from unittest import mock - -def test_sanity(): - assert 1 + 1 
== 2 - - def test_downsample_basic(): """Test apply_mask when a zeroed mask is provided.""" img_dim = (10, 10) @@ -37,9 +32,9 @@ def test_downsample_basic(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_initialization(): @@ -121,9 +116,9 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), ( - "apply_mask should return the window when mask is None." - ) + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." def test_apply_mask_with_valid_mask(): @@ -139,9 +134,9 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), ( - "apply_mask did not apply the mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." def test_apply_mask_with_zeroed_mask(): @@ -156,9 +151,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." 
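The `apply_mask` tests above fix a simple contract: with `mask=None` the estimator's boolean `window` is returned unchanged, otherwise the result is the element-wise AND of `window` and the mask, so a zeroed mask excludes every pixel. A minimal sketch of that contract, assuming only a boolean `window` array (the function name here is hypothetical, not the estimator's real method):

```python
from typing import Optional

import numpy as np


def apply_mask_sketch(window: np.ndarray, mask: Optional[np.ndarray]) -> np.ndarray:
    """Illustrative masking contract matching the assertions in the tests above."""
    if mask is None:
        # No mask provided: keep the estimator's own window.
        return window
    # Element-wise AND with the mask; an all-zero mask removes every pixel.
    return window & mask.astype(bool)


# Usage mirroring test_apply_mask_with_zeroed_mask:
window = np.ones((10, 10), dtype=bool)
assert not apply_mask_sketch(window, np.zeros((10, 10))).any()
```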
def test_unobscured_zernike_projection(): @@ -252,6 +247,7 @@ def test_tf_decompose_obscured_opd_basis(): assert rmse_error < tol + def test_downsample_basic(): """Downsample a small array to a smaller square size.""" arr = np.arange(16).reshape(4, 4).astype(np.float32) @@ -262,9 +258,10 @@ def test_downsample_basic(): assert result.shape == (output_dim, output_dim), "Output shape mismatch" # Values should be averaged/downsampled; simple check - assert np.all(result >= arr.min()) and np.all(result <= arr.max()), \ - "Values outside input range" - + assert np.all(result >= arr.min()) and np.all( + result <= arr.max() + ), "Values outside input range" + def test_downsample_identity(): """Downsample to the same size should return same array (approximately).""" @@ -274,10 +271,12 @@ def test_downsample_identity(): # Since OpenCV / skimage may do minor interpolation, allow small tolerance np.testing.assert_allclose(result, arr, rtol=1e-6, atol=1e-6) + # ---------------------------- # Backend fallback tests # ---------------------------- + @mock.patch("wf_psf.utils.utils._HAS_CV2", False) @mock.patch("wf_psf.utils.utils._HAS_SKIMAGE", False) def test_downsample_no_backend(): @@ -296,10 +295,11 @@ def test_downsample_values_average(): # All output values should be close to input value np.testing.assert_allclose(result, 3.0, rtol=1e-6, atol=1e-6) + @mock.patch("wf_psf.utils.utils._HAS_CV2", True) def test_downsample_non_square_array(): """Check downsampling works for non-square arrays.""" arr = np.arange(12).reshape(3, 4).astype(np.float32) output_dim = 2 result = downsample_im(arr, output_dim) - assert result.shape == (2, 2) \ No newline at end of file + assert result.shape == (2, 2) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index bb0e3df9..ab2f0ac1 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -7,6 +7,7 @@ """ +import gc import numpy as np import time import tensorflow as tf @@ -273,10 +274,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): - """Generate factory for loss, metrics, monitor, and outputs. - - A function to generate loss, metrics, monitor, and outputs - for training. + """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. Parameters ---------- @@ -369,12 +367,12 @@ def train( psf_model_dir : str Directory where the final trained PSF model weights will be saved per cycle. - Notes - ----- - - Utilizes TensorFlow and TensorFlow Addons for model training and optimization. - - Supports masked mean squared error loss for training with masked data. - - Allows for projection of data-driven features onto parametric models between cycles. - - Supports resetting of non-parametric features to initial states. 
+ Returns + ------- + None + + Side Effects + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations @@ -538,3 +536,8 @@ def train( final_time = time.time() logger.info("\nTotal elapsed time: %f" % (final_time - starting_time)) logger.info("\n Training complete..") + + # Clean up memory + del psf_model + gc.collect() + tf.keras.backend.clear_session() diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 5d45ba79..13ceb6de 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf -from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models +from wf_psf.data.data_handler import DataHandler from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -127,28 +129,31 @@ class DataConfigHandler: def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True): try: self.data_conf = read_conf(data_conf) - except FileNotFoundError as e: - logger.exception(e) - exit() - except TypeError as e: + except (FileNotFoundError, TypeError) as e: logger.exception(e) exit() self.simPSF = psf_models.simPSF(training_model_params) + + # Extract sub-configs early + train_params = self.data_conf.data.training + test_params = self.data_conf.data.test + self.training_data = DataHandler( dataset_type="training", - data_params=self.data_conf.data, + data_params=train_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) self.test_data = DataHandler( dataset_type="test", - data_params=self.data_conf.data, + data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size @@ -183,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) @@ -254,8 +260,13 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) self._file_handler = file_handler - self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.training_conf = training_conf + self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) + self.trained_psf_model = self._load_trained_psf_model() @property def metrics_conf(self): @@ -270,32 +281,29 @@ def metrics_conf(self): """ return self._metrics_conf - @property - def metrics_dir(self): - """Get Metrics Directory. - - A function that returns path - of metrics directory. 
- - Returns - ------- - str - Absolute path to metrics directory - """ - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): - """Get Training Conf. - - A function to return the training configuration file name. + """Returns the loaded training configuration.""" + return self._training_conf - Returns - ------- - RecursiveNamespace - An instance of the training configuration file. + @training_conf.setter + def training_conf(self, training_conf): """ - return self._training_conf + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf @property def plotting_conf(self): @@ -310,112 +318,106 @@ def plotting_conf(self): """ return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - """Get Data Conf. - - A function to return an instance of the DataConfigHandler class. + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, + self.data_conf, + weights_path_pattern, + ) - Returns - ------- - An instance of the DataConfigHandler class. + def _get_training_conf_path_from_metrics(self): """ - return self._load_data_conf() - - @property - def psf_model(self): - """Get PSF Model. - - A function to return an instance of the PSF model - to be evaluated. + Retrieves the full path to the training config based on the metrics configuration. Returns ------- - psf_model: obj - An instance of the PSF model to be evaluated. + str + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. + FileNotFoundError + If the file does not exist at the constructed path. """ - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, - self.data_conf, + trained_model_path = self._get_trained_model_path() + + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e + + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, ) - @property - def weights_path(self): - """Get Weights Path. + if not os.path.exists(training_conf_path): + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" + ) - A function to return the full path - of the user-specified psf model weights to be loaded. 
+ return training_conf_path - Returns - ------- - str - A string representing the full path to the psf model weights to be loaded. + def _get_trained_model_path(self): """ - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. + Determine the trained model path from either: - Helper method to get the trained model path. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- str - A string representing the path to the trained model output run directory. + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) - except TypeError as e: - logger.exception(e) + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" ) + return trained_model_path - def _load_training_conf(self, training_conf): - """Load Training Conf. - - Load the training configuration if training_conf is not provided. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. - - Returns - ------- - RecursiveNamespace storing the training configuration parameter settings. - - """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." - ) - else: - return training_conf + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) + return fallback_path def _load_data_conf(self): """Load Data Conf. @@ -439,27 +441,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified psf model weights path. - - Returns - ------- - weights_basename: str - The basename of the psf model weights to be loaded. 
- - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -502,18 +483,17 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. """ - logger.info(f"Running metrics evaluation on psf model: {self.weights_path}") + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, - self.weights_path, + self.trained_psf_model, self.metrics_dir, ) diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/utils/preprocessing.py deleted file mode 100644 index 210c03e5..00000000 --- a/src/wf_psf/utils/preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Preprocessing. - -A module with utils to preprocess data. - -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. 
- tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index 875ae8ed..48d23e00 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -140,4 +140,4 @@ def read_stream(conf_file): docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: # noqa: UP028 - yield doc + yield doc diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 1b1f2d6d..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -26,6 +26,47 @@ pass +def scale_to_range(input_array, old_range, new_range): + # Scale to [0,1] + input_array = (input_array - old_range[0]) / (old_range[1] - old_range[0]) + # Scale to new_range + input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] + return input_array + + +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. + + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. + """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + + +def calc_wfe(zernike_basis, zks): + wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) + return wfe + + +def calc_wfe_rms(zernike_basis, zks, pupil_mask): + wfe = calc_wfe(zernike_basis, zks) + wfe_rms = np.sqrt(np.mean((wfe[pupil_mask] - np.mean(wfe[pupil_mask])) ** 2)) + return wfe_rms + + def generalised_sigmoid(x, max_val=1, power_k=1): """ Apply a generalized sigmoid function to the input.