diff --git a/.github/workflows/cd_manual.yml b/.github/workflows/cd_manual.yml index 4064cf67..9c5568a4 100644 --- a/.github/workflows/cd_manual.yml +++ b/.github/workflows/cd_manual.yml @@ -24,7 +24,8 @@ jobs: - name: Install dependencies run: | - python -m pip install ".[docs]" + python -m pip install --upgrade pip setuptools wheel + pip install -e ".[docs]" - name: Build API documentation run: | @@ -35,4 +36,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: api-docs - publish_dir: docs/build/html + path: docs/build/html diff --git a/.gitignore b/.gitignore index 1c4865fc..7f3bc838 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,7 @@ instance/ # Sphinx documentation docs/build/ +docs/_build/ docs/source/wf_psf*.rst docs/source/_static/file.png docs/source/_static/images/logo_colab.png diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..453a504e --- /dev/null +++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,45 @@ + + + + +### New features + +- Added PSF inference capabilities for generating broadband (polychromatic) PSFs from trained models given star positions and SEDs +- Introduced `PSFInferenceEngine` class to centralize training, simulation, metrics, and inference workflows +- Added `run_type` attribute to `DataHandler` supporting training, simulation, metrics, and inference modes +- Implemented `ZernikeInputsFactory` class for building `ZernikeInputs` instances based on run type +- Added `psf_model_loader.py` module for centralized model weights loading + + +### Bug fixes + +- Fix logger formatting for relative RMSE metrics in `metrics.py` (values were not being displayed) + + + +### Internal changes + +- Refactored `TFPhysicalPolychromatic` and related modules to separate training vs. inference behavior +- Enhanced `ZernikeInputs` data class with intelligent assembly based on run type and available data +- Implemented hybrid loading pattern with eager loading in constructors and lazy-loading via property decorators +- Centralized PSF data extraction in `data_handler` module +- Improved code organization with new `tf_utils.py` module in `psf_models` sub-package +- Updated configuration handling to support inference workflows via `inference_config.yaml` +- Fixed incorrect argument name in `DataHandler` that prevented proper TensorFlow data type conversion +- Removed deprecated `get_obs_positions` method +- Updated documentation to include inference package diff --git a/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..7fb0859a --- /dev/null +++ b/changelog.d/20260210_150506_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,37 @@ + + + + + + + +### Internal changes + +- Remove deprecated/optional import tensorflow-addons statement from tf_layers.py + + diff --git a/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md b/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md new file mode 100644 index 00000000..0239f903 --- /dev/null +++ b/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md @@ -0,0 +1,40 @@ + + + + + + + +### Internal changes + +- Fixed Sphinx autosummary import errors by removing core dependencies (tensorflow) from autodoc_mock_imports in conf.py. +- Updated pyproject.toml to include all wf_psf packages under src/ and include config/yaml files. +- Updated example configuration files with clearer inline comments. + + + diff --git a/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md b/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md new file mode 100644 index 00000000..9b41f68f --- /dev/null +++ b/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md @@ -0,0 +1,47 @@ + + + + + + + +### Internal changes + +- API documentation for new `inference` package in `api.rst` +- API documentation for new `instrument` package in `api.rst` +- Inference Configuration section in `configuration.md` documenting `inference_config.yaml` +- Restructured Configuration documentation: + - Split workflows into "CLI Tasks" and "Additional Components" sections + - Added configuration file dependency table showing required vs optional files per task + - Clarified configuration filename flexibility (filenames customizable, internal structure fixed) + - Standardized section titles (Training Configuration, Metrics Configuration, etc.) + - Improved markdown formatting and fixed broken anchor links +- Updated `dependencies.md` to document `tensorflow-addons` as optional dependency with manual installation instructions +- `tensorflow-addons` from core dependencies documentation (now documented as optional) + + diff --git a/config/data_config.yaml b/config/data_config.yaml index 02a15341..ca6939af 100644 --- a/config/data_config.yaml +++ b/config/data_config.yaml @@ -5,43 +5,8 @@ data: data_dir: data/coherent_euclid_dataset/ # Provide name of training dataset file: train_Euclid_res_200_TrainStars_id_001.npy - # if training data set file does not exist, generate a new one by setting values below - stars: null - positions: null - SEDS: null - zernike_coef: null - C_poly: null - params: # - d_max: 2 - max_order: 45 - x_lims: [0, 1000.0] - y_lims: [0, 1000.0] - grid_points: [4, 4] - n_bins: 20 - max_wfe_rms: 0.1 - oversampling_rate: 3.0 - output_Q: 3.0 - output_dim: 32 - LP_filter_length: 2 - pupil_diameter: 256 - euclid_obsc: true - n_stars: 200 test: # Specify directory path to training dataset data_dir: data/coherent_euclid_dataset/ # Provide name of test dataset - file: test_Euclid_res_id_001.npy - # If test data set file not provided produce a new one - stars: null - noisy_stars: null - positions: null - SEDS: null - zernike_coef: null - C_poly: null - parameters: - d_max: 2 - max_order: 45 - x_lims: [0, 1000.0] - y_lims: [0, 1000.0] - grid_points: [4,4] - max_wfe_rms: 0.1 \ No newline at end of file + file: test_Euclid_res_id_001.npy \ No newline at end of file diff --git a/config/inference_config.yaml b/config/inference_config.yaml new file mode 100644 index 00000000..7f32a235 --- /dev/null +++ b/config/inference_config.yaml @@ -0,0 +1,37 @@ + +inference: + # Inference batch size + batch_size: 16 + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # Paths to the configuration files and trained model directory + configs: + # Path to the directory containing the trained model + trained_model_path: /path/to/trained/model/ + + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model + + # Relative Path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml + + # Path to the data config file (this could contain prior information) + data_config_path: + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Num of wavelength bins to reconstruct polychromatic objects. + n_bins_lda: 8 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. + output_Q: 1 + + # Dimension of the pixel PSF postage stamp + output_dim: 64 + + # Flag to perform centroid error correction + correct_centroids: False + + # Flag to perform CCD misalignment error correction + add_ccd_misalignments: True diff --git a/config/metrics_config.yaml b/config/metrics_config.yaml index 9f543c31..50dbcd75 100644 --- a/config/metrics_config.yaml +++ b/config/metrics_config.yaml @@ -51,9 +51,6 @@ metrics: # Top-hat filter to avoid the aliasing effect in the obscuration mask LP_filter_length: 2 - # Boolean to define if we use sample weights based on the noise standard deviation estimation - use_sample_weights: True - # Flag to use Zernike prior use_prior: False @@ -140,7 +137,7 @@ metrics: metrics_hparams: # Batch size to use for the evaluation. batch_size: 16 - + # Metrics and model evaluation configuration optimizer: name: 'adam' # Only standard Adam used for metrics diff --git a/config/plotting_config.yaml b/config/plotting_config.yaml index fcb17ecd..8abbad89 100644 --- a/config/plotting_config.yaml +++ b/config/plotting_config.yaml @@ -1,13 +1,18 @@ plotting_params: - # Specify path to parent folder containing wf-outputs-xxxxxxxxxxx for all runs, ex: $WORK/wf-outputs/ + # Path to the parent folder containing WaveDiff output directories metrics_output_path: - # List all of the parent output directories (i.e. wf-outputs-xxxxxxxxxxx) that contain metrics results to be included in the plot + + # List of output directories whose metrics should be plotted + # Leave commented/empty if plotting immediately after a metrics run metrics_dir: - # - wf-outputs-xxxxxxxxxxx1 - # - wf-outputs-xxxxxxxxxxx2 - # List of name of metric config file to add to plot (would like to change such that code goes and finds them in the metrics_dir) + # - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx2 + + # List of metrics config filenames corresponding to each directory + # Leave commented/empty if plotting immediately after a metrics run metrics_config: - # - metrics_config_1.yaml - # - metrics_config_2.yaml - # Show Plots Flag + # - metrics_config_1.yaml + # - metrics_config_2.yaml + + # If True, plots are shown interactively during execution plot_show: False \ No newline at end of file diff --git a/config/training_config.yaml b/config/training_config.yaml index 9fa8a50f..6348c416 100644 --- a/config/training_config.yaml +++ b/config/training_config.yaml @@ -52,7 +52,7 @@ training: ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt # Boolean to use sample weights based on the noise standard deviation estimation - use_sample_weights: True + use_sample_weights: True # Sample weight generalised sigmoid function sample_weights_sigmoid: @@ -96,7 +96,7 @@ training: # Telescope's focal length in [m]. Default is `24.5`[m] (Euclid-like). tel_focal_length: 24.5 - # Wheter to use Euclid-like obscurations. + # Use Euclid-like obscurations. euclid_obsc: True # Length of one dimension of the Low-Pass (LP) filter to apply to the @@ -154,9 +154,9 @@ training: loss: 'mask_mse' # Optimizer to use during training. Options are: 'adam' or 'rectified_adam'. - optimizer: - name: 'rectified_adam' - + optimizer: + name: 'rectified_adam' + multi_cycle_params: # Number of training cycles to perform. Each cycle may use different learning rates or number of epochs. diff --git a/docs/source/api.rst b/docs/source/api.rst index 3bc4d395..868493f4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -8,6 +8,8 @@ This section contains the API reference for the main packages in WaveDiff. :recursive: wf_psf.data + wf_psf.inference + wf_psf.instrument wf_psf.metrics wf_psf.plotting wf_psf.psf_models diff --git a/docs/source/conf.py b/docs/source/conf.py index 988df8fe..ff6ae3fe 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,14 +1,16 @@ -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys -import os -from datetime import datetime +""" +Sphinx configuration for the wf-psf documentation. -current_year = datetime.now().year +This file sets up paths, extensions, theme, and other options +for building the HTML docs. +""" + +from datetime import datetime +import os +import sys -sys.path.insert(0, os.path.abspath("src/wf_psf")) +repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(repo_root, "src")) # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information @@ -21,7 +23,7 @@ else: copyright = f"{start_year}, CosmoStat" author = "CosmoStat" -release = "3.0.0" +release = "3.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -103,7 +105,7 @@ bibtex_reference_style = "author_year" # -- Mock imports for documentation ------------------------------------------ -autodoc_mock_imports = [ - "tensorflow", +optional_deps = [ "tensorflow_addons", ] +autodoc_mock_imports = optional_deps diff --git a/docs/source/configuration.md b/docs/source/configuration.md index ba87cef2..b7e4d638 100644 --- a/docs/source/configuration.md +++ b/docs/source/configuration.md @@ -1,59 +1,84 @@ # Configuration -WaveDiff uses a set of YAML and INI configuration files to control each pipeline task. -This section provides a high-level overview of the configuration system, followed by detailed explanations of each file. +WaveDiff uses a set of YAML and INI configuration files to control each pipeline task. This section provides a high-level overview of the configuration system, followed by detailed explanations of each file. -## Overview of Pipeline Tasks +## Overview of Workflows -WaveDiff consists of four main pipeline tasks: +WaveDiff consists of three CLI tasks, configured by passing a configuration file to the `wavediff` command (e.g., `wavediff -c configs.yaml -o output/`): -| Pipeline | Purpose | -|---------|---------| +| Task | Purpose | +|------|---------| | `training` | Trains a PSF model using the provided dataset and hyperparameters. | -| `metrics` |Evaluates model performance using multiple metrics, optionally comparing against a ground-truth model. | -| `plotting` | Generates figures summarizing the results from the metrics pipeline. | -| `sims` | Simulates stellar PSFs used as training/test datasets.
*(Currently executed via a standalone script, not via `wavediff`.)* | +| `metrics` | Evaluates model performance using multiple metrics, optionally comparing against a ground-truth model. | +| `plotting` | Generates figures summarising the results from the metrics pipeline. | -You configure these tasks by passing a configuration file to the `wavediff` command (e.g., `--config configs.yaml`). +WaveDiff also provides two standalone Python APIs used outside the `wavediff` CLI: + +| Component | Purpose | +|-----------|---------| +| `sims` | Provides classes and methods for simulating monochromatic, polychromatic, and spatially-varying PSFs
for generating custom datasets. | +| `inference` | Provides classes and methods for inferring PSFs as a function of position and SED from a trained PSF
model. | ## Configuration File Structure -WaveDiff expects the following configuration files under the `config/` directory: +WaveDiff expects configuration files under the `config/` directory. **Configuration filenames are flexible** — you can name them as you wish (e.g., `training_euclid_v2.yaml`, `my_metrics.yaml`) as long as you reference them correctly in `configs.yaml` or command-line arguments. The filenames shown below are conventional defaults used in documentation examples. + +The files required depend on which task or component you are running: ``` -config -├── configs.yaml -├── data_config.yaml -├── logging.conf -├── metrics_config.yaml -├── plotting_config.yaml -└── training_config.yaml +config/ +├── configs.yaml # Master configuration (all CLI tasks) +├── data_config.yaml # Dataset paths (training task only) +├── logging.conf # Logging configuration (all CLI tasks) +├── training_config.yaml # Training task +├── metrics_config.yaml # Metrics task +├── plotting_config.yaml # Plotting task +└── inference_config.yaml # Inference API ``` -- All `.yaml` files use standard **YAML** syntax and are loaded as nested dictionaries of key–value pairs. -- `logging.conf` uses standard **INI** syntax and configures logging behavior. -- Users may modify values but should **not rename keys or section names**, as the software depends on them. +| Task / Component | Required | Optional | +|---------------|--------------|---------| +| `training` | `configs.yaml`, `data_config.yaml` ,
`logging.conf`, `training_config.yaml` | `metrics_config.yaml`
(_triggers post-training metrics_) | +| `metrics` | `configs.yaml`, `logging.conf`,
`metrics_config.yaml` | `plotting_config.yaml`
(_triggers post-metrics plotting_) | +| `plotting` | `configs.yaml`, `logging.conf`,
`plotting_config.yaml`| — | +| `inference` | `inference_config.yaml` | `data_config.yaml` | + +**Notes:** + +- **Configuration filenames are flexible.** The names shown above (e.g., `training_config.yaml`) are conventional defaults. You may use any filename as long as you reference it correctly in `configs.yaml` or via command-line arguments. +- **Keys and section names within configuration files must be preserved.** While you can rename files, the internal YAML structure (keys like `model_params`, `training`, etc.) must remain unchanged, as the software depends on them. +- The metrics and plotting tasks retrieve dataset paths from the trained model's configuration and do not require `data_config.yaml`. +- When `metrics_config.yaml` is specified as optional for the `training` task, metrics evaluation runs automatically after training completes. +- When `plotting_config.yaml` is specified as optional for the `metrics` task, plots are generated automatically after metrics evaluation completes. +- `logging.conf` uses standard INI syntax and configures logging behaviour for all CLI tasks. Each of the configuration files is described in detail below. (data_config)= -## `data_config.yaml` — Data Configuration +## Data Configuration ### 1. Purpose -Specifies where WaveDiff loads (or later versions may generate) the training and test datasets. -All training, evaluation, and metrics pipelines depend on this file for consistent dataset paths. +Specifies the training and test datasets used by the training CLI task. + ### 2. Key Fields -- `data.training.data_dir` _(required)_ — directory containing training data -- `data.training.file` _(required)_ — filename of the training dataset -- `data.test.*` — same structure as `training`, for the test dataset -- **Simulation-related fields** — reserved for future releases -**Notes** -- The simulation options are placeholders; WaveDiff v3.x does **not yet** auto-generate datasets. -- The default dataset bundled with WaveDiff can be used by simply pointing to its directory. +Both `data.training` and `data.test` share the same structure: + +| Field | Required | Description | +|-----------|--------------|--------------| +| `data_dir` | Yes | Path to the directory containing the dataset. | +| `file` | Yes | Filename of the dataset (`.npy`). | + + +### 3. Notes + +- The default dataset bundled with WaveDiff can be used by pointing `data_dir` to its installation directory. +- The `metrics` and `plotting` tasks retrieve dataset paths automatically from the trained model's configuration file and do not require this file. +- This file is optional for the `inference` API; see [inference_config.yaml](inference_config) if you need to supply prior information for inference. + +### 4. Example -**Example (minimal)** ```yaml data: training: @@ -65,51 +90,52 @@ data: ``` (training_config)= -## `training_config.yaml` — Training Pipeline Configuration +## Training Configuration ### 1. Purpose -Controls the training pipeline, including model selection, hyperparameters, optional metrics evaluation, and data loading behavior. +Controls the training pipeline, including model selection, hyperparameters, optional post-training metrics evaluation, and data loading behaviour. ### 2. General Notes -- Every field has an inline comment in the YAML file. -- **All required parameters must be specified.** Missing values will prevent the model from being instantiated, as there is currently no default configuration provided. +- **All required parameters must be specified.** There is currently no default configuration — missing values will prevent the model from being instantiated. - **Optional fields:** - - `metrics_config` (run metrics after training) - - `param_hparams`, `nonparam_hparams` - - `multi_cycle_params.save_all_cycles` -- Some parameters are specific to physical or polychromatic PSF models. -- Example training configuration file is provided in the top-level root directory of the repository (`training_config.yaml`). Users can copy and adapt this template for their own runs. -- If any descriptions are unclear, or unexpected behaviour occurs, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). + - `metrics_config` — trigger metrics evaluation after training completes + - `multi_cycle_params.save_all_cycles`— defaults to `False` +- Some parameters are specific to the physical PSF model and may be ignored by simpler model types. +- An example training configuration file is provided in the repository root (`config/training_config.yaml`). Copy and adapt this template for your own runs. +- **Fraction notation**: Fields like `reference_shifts` accept fraction strings (e.g., "`-1/3`") which are automatically converted to floats. You can also use decimal values directly (e.g., `-0.333`). +- Every field in the YAML file includes an inline comment. If any descriptions remain unclear or unexpected behavior occurs, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). -**Note:** The values in the examples shown below correspond to a typical WaveDiff training run. Users should adapt parameters such as `model_name`, telescope dimensions, pixel/field coordinates, and SED settings to match their own instrument or dataset. All required fields must still be specified. +**Note on example values**: The parameter values shown below correspond to a typical Euclid-like WaveDiff training run. Adapt `model_name`, telescope dimensions, pixel/field coordinates, and SED settings to match your instrument and dataset. ### 3. Top-Level Training Parameters +`training` ```yaml training: - # ID name for this run (used in output files) - id_name: run__001 + # ID name for this run (used in output filenames and logs) + id_name: run_001 - # Path to Data Configuration file (required) + # Path to data configuration file data_config: data_config.yaml # Load dataset on initialization (True) or manually later (False) load_data_on_init: True - # Optional: metrics configuration to run after training + # Optional: path to metrics configuration to run after training metrics_config: ``` -### 4. Model Parameters (`model_params`) +### 4. Model Parameters -Controls PSF model type, geometry, oversampling, and preprocessing: +Controls PSF model type, geometry, oversampling, and physical corrections. +`training.model_params` ```yaml model_params: - # Model type. Options: 'poly' and 'physical_poly' + # Model type. Options: 'poly', 'physical_poly' model_name: physical_poly # Number of wavelength bins for polychromatic reconstruction @@ -134,25 +160,22 @@ model_params: # Centroid correction parameters sigma_centroid_window: 2.5 # Std dev of centroiding window - reference_shifts: [-1/3, -1/3] # Euclid-like default shifts + reference_shifts: [-0.333, -0.333] # Reference pixel shifts (Euclid default: -1/3, -1/3) - # Obscuration / geometry - obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation. + # Obscuration geometry + obscuration_rotation_angle: 0 # Rotation in degrees (multiples of 90); counterclockwise # CCD misalignments input file path ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt - - # Boolean to use sample weights based on the noise standard deviation estimation - use_sample_weights: True - # Sample weight generalised sigmoid function + # Sample weighting based on noise standard deviation + use_sample_weights: True + + # Sample weight sigmoid function parameters sample_weights_sigmoid: - # Boolean to define if we apply the sigmoid function to the sample weights - apply_sigmoid: False - # Maximum value of the sigmoid function and consequently the maximum value of the sample weights - sigmoid_max_val: 5.0 - # Power of the sigmoid function. The higher the value the steeper the sigmoid function. In the limit - sigmoid_power_k: 1.0 + apply_sigmoid: False # Enable sigmoid weighting transform + sigmoid_max_val: 5.0 # Maximum sample weight value + sigmoid_power_k: 1.0 # Sigmoid steepness (higher = steeper) # Interpolation settings for physical-poly model interpolation_type: None @@ -165,30 +188,32 @@ model_params: sed_sigma: 0 # Field and pixel coordinates - x_lims: [0.0, 1.0e3] - y_lims: [0.0, 1.0e3] - pix_sampling: 12 # in [um] + x_lims: [0.0, 1000.0] + y_lims: [0.0, 1000.0] + pix_sampling: 12 # Pixel size in microns # Telescope parameters - tel_diameter: 1.2 # [m] - tel_focal_length: 24.5 # [m] - euclid_obsc: True. # Use Euclid-specific obscuration mask. Set to False for other instruments or custom masks. - LP_filter_length: 3 # Low-pass filter for obscurations + tel_diameter: 1.2 # Aperture diameter in meters + tel_focal_length: 24.5 # Focal length in meters + euclid_obsc: True # Use Euclid-specific obscuration mask (set False for other instruments) + LP_filter_length: 3 # Low-pass filter kernel size for obscurations ``` -### 5. Parametric Model Hyperparameters (`param_hparams`) +### 5. Parametric Model Hyperparameters +`training.model_params.param_hparams` ```yaml param_hparams: random_seed: 3877572 - l2_param: 0.0 # L2 loss for OPD/WFE - n_zernikes: 15 - d_max: 2 # Max polynomial degree - save_optim_history_param: true + l2_param: 0.0 # L2 regularization weight for OPD/WFE + n_zernikes: 15 # Number of Zernike polynomials + d_max: 2 # Maximum polynomial degree + save_optim_history_param: True ``` -### 6. Non-Parametric Model Hyperparameters (`nonparam_hparams`) +### 6. Non-Parametric Model Hyperparameters +`training.model_params.nonparam_hparams` ```yaml nonparam_hparams: d_max_nonparam: 5 @@ -196,38 +221,46 @@ nonparam_hparams: l1_rate: 1.0e-8 project_dd_features: False reset_dd_features: False - save_optim_history_nonparam: true + save_optim_history_nonparam: True ``` -### 7. Training Hyperparameters (`training_hparams`) +### 7. Training Hyperparameters -Controls batches, loss, and multi-cycle learning: +Controls batch size, loss function, optimizer selection, and multi-cycle learning. +`training.training_hparams` ```yaml training_hparams: - batch_size: 32 # Number of samples per batch - loss: 'mask_mse' # Options: 'mask_mse', 'mse' + batch_size: 32 # Number of samples per training batch + loss: 'mask_mse' # Loss function. Options: 'mask_mse', 'mse' + optimizer: + name: 'rectified_adam' # Options: 'adam', 'rectified_adam' multi_cycle_params: total_cycles: 2 - cycle_def: complete # Options: 'parametric', 'non-parametric', 'complete', etc. - save_all_cycles: False - saved_cycle: cycle2 - - learning_rate_params: [1.0e-2, 1.0e-2] - learning_rate_non_params: [1.0e-1, 1.0e-1] - n_epochs_params: [20, 20] - n_epochs_non_params: [100, 120] + cycle_def: complete # Options: 'parametric', 'non-parametric', 'complete' + save_all_cycles: False # If True, saves checkpoints for all cycles; otherwise only saved_cycle + saved_cycle: cycle2 # Which cycle checkpoint to retain + + learning_rate_params: [1.0e-2, 1.0e-2] # Per-cycle learning rate for parametric model + learning_rate_non_params: [1.0e-1, 1.0e-1] # Per-cycle learning rate for non-parametric model + n_epochs_params: [20, 20] # Per-cycle epochs for parametric model + n_epochs_non_params: [100, 120] # Per-cycle epochs for non-parametric model ``` +**Optimizer Notes:** +- `rectified_adam` requires tensorflow-addons to be installed manually. +- If TensorFlow Addons is not installed and `rectified_adam` is requested, WaveDiff will raise a runtime error with installation instructions. +- Standard workflows (`training`, `metrics`, `plotting`) run without TensorFlow Addons. (metrics_config)= -## `metrics_config.yaml` — Metrics Configuration +## Metrics Configuration ### 1. Purpose Defines how a trained PSF model is evaluated. This configuration specifies which metrics to compute, which model weights to use, and how ground truth stars are obtained. It allows you to: + - Select a fully trained PSF model or a checkpoint for evaluation. -- Specify which training cycle’s weights to evaluate. +- Specify which training cycle's weights to evaluate. - Compute Polychromatic, Monochromatic, OPD, and Weak Lensing Shape metrics. - Use precomputed ground truth stars from the dataset if available, or automatically generate them from the configured ground truth model. - Optionally produce plots of the computed metrics via a plotting configuration file. @@ -236,7 +269,7 @@ Defines how a trained PSF model is evaluated. This configuration specifies which - WaveDiff automatically searches the dataset used for training. If the dataset contains `stars`, `SR_stars`, or `super_res_stars` fields, these are used as the ground truth for metrics evaluation. - If precomputed ground truth stars are not found in the dataset, WaveDiff regenerates them using the `ground_truth_model` parameters. **All required fields in `model_params` must be specified**; leaving them empty will prevent the metrics pipeline from running (see [Ground Truth Model Parameters](section-ground-truth-model) for details). -- The metrics evaluation can be run independently of training by specifying trained_model_path and `trained_model_config`. +- Metrics evaluation can be run independently of training by specifying both `trained_model_path` and `trained_model_config` to point to a previously trained model. - Metrics defined in [Metrics Overview table](metrics-table) are selectively computed according to their boolean flags. The Polychromatic Pixel Reconstruction metric is always computed. - The `plotting_config` parameter triggers plotting of the metrics results if a valid configuration file is provided. If left empty, metrics are computed without generating plots (see [Plotting Configuration](section-plotting-config)). - Batch size and other evaluation hyperparameters can be set under `metrics_hparams` (see [Evaluation Hyperparameters](section-evaluation-hyperparameters)) @@ -259,6 +292,8 @@ Defines how a trained PSF model is evaluated. This configuration specifies which ### 4. Top-Level Configuration Parameters +`metrics` + ```yaml metrics: model_save_path: @@ -272,25 +307,29 @@ metrics: plotting_config: ``` +**Parameter descriptions:** + +- `model_save_path`: Specifies which weights to load. Options: `psf_model` (final trained weights) or `checkpoint` (intermediate checkpoint). +- `saved_training_cycle`: Which training cycle to evaluate (e.g., `1`, `2`, ...). +- `trained_model_path`: Absolute path to the parent directory of a previously trained model. Leave empty if running `training` + `metrics` sequentially in the same workflow. +- `trained_model_config`: Filename of the training configuration (located in `/config/`). +- `eval_mono_metric`: If `True`, computes the monochromatic pixel reconstruction metric. Requires `ground_truth_model` to be configured (see [Ground Truth Model Parameters](section-ground-truth-model)). +- `eval_opd_metric`: If `True`, computes the optical path difference (OPD) metric. Requires `ground_truth_model` to be configured. +- `eval_train_shape_results_dict`: If `True`, computes Weak Lensing Shape metrics on the training dataset. +- `eval_test_shape_results_dict`: If `True`, computes Weak Lensing Shape metrics on the test dataset. +- `plotting_config`: Optional filename of a plotting configuration (e.g., `plotting_config.yaml`) to automatically generate plots after metrics evaluation. Leave empty to skip plotting. + **Notes:** -- `model_save_path`: Load final PSF model weights (`psf_model`) or checkpoint weights (`checkpoint`). -- `saved_training_cycle`: Choose which training cycle to evaluate (1, 2, …). -- `trained_model_path`: Absolute path to parent directory of previously trained model. Leave empty for training + metrics in serial. -- `trained_model_config`: Name of training config file in `trained_model_path/config/`. -- `eval_mono_metric`: If True, computes the monochromatic pixel reconstruction metric. Requires a `ground_truth_model` (see). -- `eval_opd_metric`: If True, computes the optical path difference (OPD) metric. Requires a `ground_truth_model`. -- `eval_train_shape_results_dict` / `eval_test_shape_results_dict`: Compute Weak Lensing Shape metrics on the training and/or test dataset. -- `plotting_config:` Optionally provide a `plotting_config.yaml` file to generate plots after metrics evaluation. -- **Behaviour notes:** - - Metrics controlled by flags (`eval_mono_metric`, `eval_opd_metric`, `eval_train_shape_results_dict`, `eval_test_shape_results_dict`) are only computed if their respective flags are True. - - The Polychromatic Pixel Reconstruction metric is always computed, regardless of flags. - - Future releases may allow optional `ground_truth_model` instantiation if the dataset already contains precomputed stars. + +- The Polychromatic Pixel Reconstruction metric is **always computed** regardless of flag settings. +- All other metrics (`eval_mono_metric`, `eval_opd_metric`, `eval_train_shape_results_dict`, `eval_test_shape_results_dict`) are only computed when their respective flags are set to `True`. (section-ground-truth-model)= ### 5. Ground Truth Model Parameters -Mirrors training parameters for consistency: +Specifies parameters for generating ground truth PSFs when precomputed stars are not available in the dataset. This configuration includes a subset of the training parameters — only those needed to simulate ground truth PSFs for comparison. +`metrics.ground_truth_model` ```yaml ground_truth_model: model_params: @@ -301,11 +340,10 @@ ground_truth_model: output_dim: 32 pupil_diameter: 256 LP_filter_length: 2 - use_sample_weights: True use_prior: False correct_centroids: False sigma_centroid_window: 2.5 - reference_shifts: [-1/3, -1/3] + reference_shifts: [-0.333, -0.333] obscuration_rotation_angle: 0 add_ccd_misalignments: False ccd_misalignments_input_path: @@ -314,11 +352,11 @@ ground_truth_model: sed_extrapolate: True sed_interp_kind: linear sed_sigma: 0 - x_lims: [0.0, 1.0e+3] - y_lims: [0.0, 1.0e+3] + x_lims: [0.0, 1000.0] + y_lims: [0.0, 1000.0] param_hparams: random_seed: 3877572 - l2_param: 0. + l2_param: 0.0 n_zernikes: 45 d_max: 2 save_optim_history_param: True @@ -331,32 +369,48 @@ ground_truth_model: save_optim_history_nonparam: True ``` **Notes:** -- **All fields in `model_params` are required.** Do not leave them empty. Even if the dataset contains precomputed ground truth stars, omitting `model_params` will prevent the metrics pipeline from running. -- Parameters mirror `training_config.yaml` for consistency. +- **All fields shown above are required.** Do not leave them empty. Even if the dataset contains precomputed ground truth stars, omitting these fields will prevent the metrics pipeline from running. +- This configuration uses a subset of the training parameters — telescope geometry (`tel_diameter`, `tel_focal_length`, `pix_sampling`) and sample weighting (`use_sample_weights`, `sample_weights_sigmoid`) are not required for metrics evaluation, as these are only needed during model training. +- Ground truth model parameters should match the simulation settings used to generate your dataset for meaningful comparison. - Future releases may allow optional instantiation of `ground_truth_model` when precomputed stars are available in the dataset. (section-evaluation-hyperparameters)= ### 6. Evaluation Hyperparameters +`metrics.metrics_hparams` ```yaml metrics_hparams: batch_size: 16 opt_stars_rel_pix_rmse: False - l2_param: 0. + l2_param: 0.0 output_Q: 1 output_dim: 64 + + # Optimizer configuration for metrics evaluation + optimizer: + name: 'adam' # Fixed to Adam for metrics evaluation + learning_rate: 1.0e-2 + beta_1: 0.9 + beta_2: 0.999 + epsilon: 1.0e-7 + amsgrad: False ``` -**Parameter explanations:** +**Parameter descriptions:** - `batch_size`: Number of samples processed per batch during evaluation. -- `opt_stars_rel_pix_rmse`: If `True`, saves RMSE for each individual star in addition to mean across FOV. -- `l2_param`: L2 loss weight for OPD. -- `output_Q`: Downsampling rate from high-resolution pixel modeling space. -- `output_dim`: Size of the PSF postage stamp for evaluation. - +- `opt_stars_rel_pix_rmse`: (_optional individual star RMSE_) If `True`, saves the relative pixel RMSE for each individual star in the test dataset in addition to the mean across the field of view. +- `l2_param`: L2 loss weight for the OPD metric. +- `output_Q`: Downsampling rate from the high-resolution pixel modeling space to the resolution at which PSF shapes are measured. Recommended value: `1`. +- `output_dim`: Pixel dimension of the PSF postage stamp. Should be large enough to contain most of the PSF signal. The required size depends on the `output_Q` value used. Recommended value: `64` or higher. +- `optimizer`: Optimizer configuration for metrics evaluation. Unlike training, metrics evaluation always uses the standard Adam optimizer. + - `name`: Fixed to `'adam'` (no other optimizers supported for metrics). + - `learning_rate`: Learning rate for optimizer. + - `beta_1, beta_2`: Exponential decay rates for moment estimates. + - `epsilon`: Small constant for numerical stability. + - `amsgrad`: If `True`, uses AMSGrad variant of Adam. (section-plotting-config)= -## `plotting_config.yaml` — Plot Configuration +## Plotting Configuration The `plotting_config.yaml` file defines how WaveDiff generates diagnostic plots from the metrics produced during model evaluation. While the plotting routines are mostly pre-configured internally, this file allows you to combine and compare metrics from multiple training runs, or simply visualize the results of the most recent `metrics` pipeline execution. @@ -374,49 +428,58 @@ This configuration controls how metric outputs from one or more WaveDiff runs ar - All plotting styles and figure settings are hard-coded and do not require user modification. - If the plotting task is executed immediately after a metrics evaluation run, all fields except `plot_show` may be left empty—the pipeline will automatically locate the outputs of the active run. - When plotting results from multiple runs, the entries in `metrics_dir` and `metrics_config` must appear **row-aligned**, with each position referring to the same run. -- If any descriptions are unclear, or if you encounter unexpected behavior, please open a GitHub issue (). +- If any descriptions are unclear, or if you encounter unexpected behavior, please open a [GitHub issue](). -### 3. Basic Structure +### 3. Configuration Structure -An example `plotting_config.yaml` is shown below: +`plotting_params` ```yaml plotting_params: - # Path to the parent folder containing wf-psf output directories (e.g. $WORK/wf-outputs/) - metrics_output_path: + # Path to the parent folder containing WaveDiff output directories + metrics_output_path: /path/to/wf-outputs/ - # List of output directories (e.g. wf-outputs-xxxxxxxxxxx) whose metrics should be plotted + # List of output directories whose metrics should be plotted + # Leave commented/empty if plotting immediately after a metrics run metrics_dir: - # - wf-outputs-xxxxxxxxxxx1 - # - wf-outputs-xxxxxxxxxxx2 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx2 - # List of the metric config filenames corresponding to each listed directory + # List of metrics config filenames corresponding to each directory + # Leave commented/empty if plotting immediately after a metrics run metrics_config: - # - metrics_config_1.yaml - # - metrics_config_2.yaml + # - metrics_config_1.yaml + # - metrics_config_2.yaml # If True, plots are shown interactively during execution plot_show: False ``` +**Parameter descriptions:** + +- `metrics_output_path`: Absolute path to the parent directory containing WaveDiff output folders (e.g., `/home/user/wf-outputs/`). Can be left as `` placeholder if plotting immediately after a metrics run. +- `metrics_dir`: List of output directory names (e.g., `wf-outputs-xxxxxxxxxxxxxxxxxxx1`) whose metrics should be included in plots. **Leave empty or commented out if plotting immediately after a metrics run** — WaveDiff will automatically locate the current run's outputs. +- `metrics_config`: List of `metrics_config.yaml` filenames corresponding to each directory in `metrics_dir`. Each entry should match the config file in `/config/`. Must be row-aligned with `metrics_dir`. **Leave empty or commented out if plotting immediately after a metrics run.** +- `plot_show`: If `True`, displays plots interactively during execution. If `False`, plots are saved to disk without display. + ### 4. Example Directory Structure Below is an example of three WaveDiff runs stored under a single parent directory: ``` wf-outputs/ -├── wf-outputs-202305271829 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx1 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_200.yaml │ ├── metrics │ │ └── metrics-poly-coherent_euclid_200stars.npy -├── wf-outputs-202305271845 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx2 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_500.yaml │ ├── metrics │ │ └── metrics-poly-coherent_euclid_500stars.npy -├── wf-outputs-202305271918 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx3 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_1000.yaml @@ -429,12 +492,12 @@ To jointly plot metrics from the three runs shown above, the `plotting_config.ya ```yaml plotting_params: - metrics_output_path: $WORK/wf-outputs/ + metrics_output_path: /path/to/wf-outputs/ metrics_dir: - - wf-outputs-202305271829 - - wf-outputs-202305271845 - - wf-outputs-202305271918 + - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + - wf-outputs-xxxxxxxxxxxxxxxxxxx2 + - wf-outputs-xxxxxxxxxxxxxxxxxxx3 metrics_config: - metrics_config_200.yaml @@ -445,18 +508,94 @@ plotting_params: ``` This configuration instructs the plotting pipeline to load the metrics from each listed run and include them together in summary plots. +(inference_config)= +## Inference Configuration + +### 1. Purpose +Configures the WaveDiff inference API for generating polychromatic PSFs from a trained model, given a set of source positions and SEDs. Unlike the CLI tasks, the inference API is designed for external use: users are expected to load their own positions and SEDs programmatically and interact with the API directly. + +### 2. Key Fields + +`inference` +| Field | Required | Description | +|---------------|--------------|---------| +| `batch_size` | Yes | Number of PSFs to process per batch. | +| `cycle` | Yes | Training cycle checkpoint to load (e.g. `2`).
WaveDiff training typically runs two cycles.| + +`inference.configs` + +| Field | Required | Description | +|---------------|--------------|---------| +| `trained_model_path` | Yes | Absolute path to the directory containing the trained
model. | +| `model_subdir` | Yes | Subdirectory name within `trained_model_path`
containing the model weights (e.g. model). | +|`trained_model_config_path` | Yes | Path to the training configuration file used to train the
model, relative to `trained_model_path`. | +| `data_config_path` | No. | Path to a data configuration file supplying prior
information (e.g. a Phase Diversity calibration prior)
relevant to the inference context. This may differ
from the data configuration used during training. Leave
blank if no external prior is required. + +`inference.model_params` + +These fields are optional. Any field left blank inherits its value from the trained model configuration file. Populated fields override the corresponding `model_params` values from the training config. + +| Field | Required | Description | +|---------------|--------------|---------| +| `n_bins_lda` | inherited | Number of wavelength bins used to reconstruct polychromatic PSFs.| +| `output_Q` | inherited | Downsampling rate to match the oversampled model to the telescope's
native sampling. | +| `output_dim` | inherited | Pixel dimension of the output PSF postage stamp. | +| `correct_centroids` | False | If `True`, applies centroid error correction within the PSF model during inference.. | +| `add_ccd_misalignments` | False | If `True`, incorporates CCD misalignment corrections into
the PSF model during inference. Required data is retrieved
from the trained model configuration file. | + +### 3. Example + +```yaml +inference: + batch_size: 16 + cycle: 2 + configs: + trained_model_path: /path/to/trained/model/ + model_subdir: model + trained_model_config_path: config/training_config.yaml + data_config_path: + model_params: + n_bins_lda: 8 + output_Q: 1 + output_dim: 64 + correct_centroids: False + add_ccd_misalignments: True +``` + +### 4. Notes + +- `trained_model_config_path` is relative to `trained_model_path`, not to the working directory. +- All `model_params` fields are optional; omitting them inherits values from the training configuration. - Only populate fields where you explicitly want to override the trained model's parameters. +- `data_config_path` is intended for cases where inference is performed in a different data context than training, for example using an updated or alternative prior. Leave blank if the trained model's own configuration is sufficient. +- `correct_centroids` and `add_ccd_misalignments` are independent model behaviour flags that modify PSF model computation during inference. Both retrieve their required data from the trained model configuration file — no additional configuration is required to enable them. + + (master_config_file)= ## Master Configuration ### 1. Purpose -The `configs.yaml` file is the _master controller_ for WaveDiff. -It defines **which pipeline tasks** should be executed (training, metrics evaluation, plotting) and in which order. +The `configs.yaml` file is the master controller for WaveDiff CLI tasks. It defines **which pipeline tasks** should be executed (`training`, `metrics`, `plotting`) and in which order. Each task entry points to a dedicated YAML configuration file, allowing WaveDiff to run multiple jobs sequentially from a single entry point. Each task points to a dedicated YAML configuration file—allowing WaveDiff to run multiple jobs sequentially using a single entry point. -### 2. Example: Multiple Training Runs +### 2. General Notes + +`configs.yaml` may contain any combination of the three CLI task types: + +- `training` +- `metrics` +- `plotting` + +-Tasks always execute **in the order they appear** in the file. +- The current release runs all jobs sequentially on a single GPU. +- Parallel multi-GPU execution is planned for a future version. +- For questions or feedback, please open a [GitHub issue](). + +### 3. Example: Multiple Training Runs + To launch a sequence of training runs (models 1…n), list each task and its corresponding configuration file: +`configs.yaml` ```yaml --- training_conf_1: training_config_1.yaml @@ -464,10 +603,11 @@ To launch a sequence of training runs (models 1…n), list each task and its cor ... training_conf_n: training_config_n.yaml ``` -Outputs will be organized as: + +WaveDiff will execute each training task sequentially and organize outputs as: ``` -wf-outputs-20231119151932213823/ +wf-outputs-xxxxxxxxxxxxxxxxxxx1/ ├── checkpoint/ │ ├── checkpoint_callback_poly-coherent_euclid_200stars_1_cycle1.* │ ├── ... @@ -486,29 +626,34 @@ wf-outputs-20231119151932213823/ └── psf_model_poly-coherent_euclid_200stars_n_cycle1.* ``` -### 3 Example: Training + Metrics + Plotting -To evaluate metrics and generate plots for each trained model, include the corresponding configuration files: +### 4. Example: Training + Metrics + Plotting +To evaluate metrics and generate plots after each training run, include metrics and plotting tasks in +`configs.yaml`: + +``` +training_conf_1: training_config_1.yaml +metrics_conf_1: metrics_config_1.yaml +plotting_conf_1: plotting_config_1.yaml +training_conf_2: training_config_2.yaml +metrics_conf_2: metrics_config_2.yaml +plotting_conf_2: plotting_config_2.yaml +... +``` + +Required configuration files: ``` config/ ├── configs.yaml ├── data_config.yaml -├── metrics_config.yaml -├── plotting_config.yaml ├── training_config_1.yaml -├── ... -└── training_config_n.yaml +├── metrics_config_1.yaml +├── plotting_config_1.yaml +├── training_config_2.yaml +├── metrics_config_2.yaml +├── plotting_config_2.yaml +└── ... ``` -Note: current WaveDiff versions generate one plot per metric per model. Creating combined plots requires a separate run [Plot Configuration](section-plotting-config). A future update will support automatic combined plots. - -### 4 General Notes +**Note:** Current WaveDiff versions generate one plot per metric per model. Creating combined comparison plots across multiple runs requires a separate plotting-only run (see [Plot Configuration](section-plotting-config)). Automatic combined plots may be supported in a future release. -- `configs.yaml` may contain **any combination** of the three task types: - - `training` - - `metrics` - - `plotting` -- Tasks always execute **in the order they appear** in the file. -- The current release runs all jobs on a single GPU, sequentially. -- Parallel multi-GPU execution is planned for a future version. -- For questions or feedback, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). diff --git a/docs/source/dependencies.md b/docs/source/dependencies.md index b1bef1da..0dea4934 100644 --- a/docs/source/dependencies.md +++ b/docs/source/dependencies.md @@ -10,7 +10,6 @@ Third-party software packages required by WaveDiff are installed automatically ( | [scipy](https://scipy.org) | {cite:t}`SciPy-NMeth:20` | | [keras](https://keras.io) | {cite:t}`chollet:2015keras`| | [tensorflow](https://www.tensorflow.org) | {cite:t}`tensorflow:15` | -| [tensorflow-addons](https://www.tensorflow.org/addons) |{cite:t}`tensorflow:15` | | [tensorflow-estimator](https://www.tensorflow.org/api_docs/python/tf/estimator) |{cite:t}`tensorflow:15` | | [zernike](https://github.com/jacopoantonello/zernike) | {cite:t}`Antonello:15` | | [opencv-python](https://docs.opencv.org/4.x/index.html) | {cite:t}`opencv_library:08` | @@ -19,4 +18,26 @@ Third-party software packages required by WaveDiff are installed automatically ( | [astropy](https://www.astropy.org) | {cite:t}`astropy:13,astropy:18`,
{cite:t}`astropy:22` | | [matplotlib](https://matplotlib.org) | {cite:t}`Hunter:07` | | [pandas](https://pandas.pydata.org) | {cite:t}`mckinney:2010pandas` | -| [seaborn](https://seaborn.pydata.org) | {cite:t}`Waskom:21` | \ No newline at end of file +| [seaborn](https://seaborn.pydata.org) | {cite:t}`Waskom:21` | + +## Optional Dependencies + +Some features in WaveDiff rely on optional third-party packages that are **not required for standard training and evaluation workflows**. + +### TensorFlow Addons (Optional) + +| Package Name | Purpose | +|--------------|---------| +| [tensorflow-addons](https://www.tensorflow.org/addons) | Optional optimizers (e.g. RectifiedAdam) | + +Starting with WaveDiff **v3.1.0**, `tensorflow-addons` is no longer a required dependency, as TensorFlow Addons reached end-of-life in May 2024. + +- By default, WaveDiff uses standard Keras/TensorFlow optimizers (e.g. `Adam`) +- TensorFlow Addons is only imported **at runtime** if explicitly requested in the configuration +- If a TensorFlow Addons optimizer is selected and the package is not installed, WaveDiff will raise a clear runtime error + +To use TensorFlow Addons optimizers, install manually: + +```bash +pip install tensorflow-addons +``` \ No newline at end of file diff --git a/environment.yml b/environment.yml index 1b9264e4..66f9873c 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,6 @@ dependencies: - pip - pip: - numpy>=1.26,<2.0 - - tensorflow-addons - tensorflow-estimator - zernike - opencv-python diff --git a/pyproject.toml b/pyproject.toml index 187ead5e..0e22269c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' dependencies = [ - "numpy>=1.26.4,<2.0", + "numpy>=1.18,<1.24", "scipy", "tensorflow==2.11.0", "tensorflow-estimator", @@ -24,7 +24,7 @@ dependencies = [ "seaborn", ] -version = "3.0.0" +version = "3.1.0" [project.optional-dependencies] docs = [ @@ -88,8 +88,15 @@ quote-style = "double" indent-style = "space" line-ending = "lf" +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] +include = ["wf_psf*"] + [tool.setuptools.package-data] -"wf_psf.config" = ["*.conf"] +"wf_psf.config" = ["*.conf", "*.yaml"] # Set per-file-ignores [tool.ruff.lint.per-file-ignores] diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py index d4394f09..988b02fe 100644 --- a/src/wf_psf/__init__.py +++ b/src/wf_psf/__init__.py @@ -2,6 +2,6 @@ # Dynamically import modules to trigger side effects when wf_psf is imported importlib.import_module("wf_psf.psf_models.psf_models") -importlib.import_module("wf_psf.psf_models.psf_model_semiparametric") -importlib.import_module("wf_psf.psf_models.psf_model_physical_polychromatic") -importlib.import_module("wf_psf.psf_models.tf_psf_field") +importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric") +importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic") +importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field") diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 80% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 8b4522bf..01135428 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,10 +8,80 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from fractions import Fraction from typing import Optional +def compute_centroid_correction( + model_params, centroid_dataset, batch_size: int = 1 +) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. + + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None + + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, @@ -58,6 +128,8 @@ def compute_zernike_tip_tilt( - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. """ + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter @@ -178,6 +250,18 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py new file mode 100644 index 00000000..60051645 --- /dev/null +++ b/src/wf_psf/data/data_handler.py @@ -0,0 +1,454 @@ +"""Data Handler Module. + +Provides tools for loading, preprocessing, and managing data used in both +training and inference workflows. + +Includes: + +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including + format conversion (e.g., to TensorFlow) and transformations + +This module serves as a central interface between raw data and modeling components. + +Authors: Jennifer Pollack , Tobias Liaudat +""" + +import os +import numpy as np +import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor +import tensorflow as tf +from typing import Optional, Union +import logging + +logger = logging.getLogger(__name__) + + +class DataHandler: + """ + DataHandler for WaveDiff PSF modeling. + + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, and inference in the WaveDiff framework. + + Parameters + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). + simPSF : PSFSimulator + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. + n_bins_lambda : int + Number of wavelength bins used to discretize SEDs. + load_data : bool, optional + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. + + Attributes + ---------- + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. + """ + + def __init__( + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, + ): + """ + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. + + Parameters + ---------- + dataset_type : str + One of {"train", "test", "inference"} indicating dataset usage. + data_params : RecursiveNamespace + Configuration object with paths, preprocessing options, and metadata. + simPSF : PSFSimulator + Used to convert SEDs to TensorFlow format. + n_bins_lambda : int + Number of wavelength bins for the SEDs. + load_data : bool, optional + Whether to automatically load and process the dataset (default: True). + dataset : dict or list, optional + A pre-loaded dataset to use directly (overrides `load_data`). + sed_data : array-like, optional + Pre-loaded SED data to use directly. If not provided but `dataset` is, + SEDs are taken from `dataset["SEDs"]`. + + Raises + ------ + ValueError + If SEDs cannot be found in either `dataset` or as `sed_data`. + + Notes + ----- + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + `load_data=True` is used. + - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. + """ + self.dataset_type = dataset_type + self.data_params = data_params + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: + self.load_dataset() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() + else: + self.dataset = None + self.sed_data = None + + @property + def tf_positions(self): + """Get positions as TensorFlow tensor.""" + return ensure_tensor(self.dataset["positions"]) + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified dataset type. + + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), + allow_pickle=True, + )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "training": + if "noisy_stars" not in self.dataset: + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "test": + if "stars" not in self.dataset: + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." + ) + elif self.dataset_type == "inference": + pass + else: + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + + if self.dataset_type == "train": + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif self.dataset_type == "test": + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. + """ + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + + self.sed_data = [ + utils.generate_SED_elems_in_tensorflow( + _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 + ) + for _sed in sed_data + ] + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) + self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """ + Extract and concatenate star-related data from training and test datasets. + + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). + test_key : str + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + Concatenated NumPy array containing the selected data from both + training and test sets. + + Raises + ------ + KeyError + If either the training or test dataset does not contain the + requested key. + + Notes + ----- + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) + + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = True, +) -> Optional[np.ndarray]: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. + run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default True + Control behavior when data is missing or keys are not found: + + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. + + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_direct_data(data, effective_key, allow_missing) + except Exception: + if allow_missing: + return None + raise + + +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: + """ + Extract data directly with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. + KeyError + If key is not found in dataset and allow_missing=False. + + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..97837558 --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,481 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: + +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. + +:Author: Tobias Liaudat + +""" + +from dataclasses import dataclass +from typing import Optional, Union +import numpy as np +import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment +from wf_psf.utils.read_config import RecursiveNamespace +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ZernikeInputs: + """Zernike-related inputs for PSF modeling, including priors and datasets for corrections. + + All fields are optional to allow flexibility across different run types (training, simulation, inference) and configurations. + + Parameters + ---------- + zernike_prior : Optional[np.ndarray] + The true Zernike prior, if provided (e.g., from PDC). Can be None if not used or not available. + centroid_dataset : Optional[Union[dict, "RecursiveNamespace"]] + Dataset used for computing centroid corrections. Should contain both training and test sets if + used. Can be None if centroid correction is not enabled or no dataset is available. + misalignment_positions : Optional[np.ndarray] + Positions used for computing CCD misalignment corrections. Should be available in inference mode if misalignment correction is enabled. Can be None if not used or not available. + """ + + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + + +class ZernikeInputsFactory: + """Factory class to build ZernikeInputs based on run type and dataset configuration. + + This class abstracts the logic of extracting the relevant Zernike-related inputs from the dataset based on the specified run type (training, simulation, inference) and model parameters. It handles the conditional logic for which inputs are needed and how to extract them, providing a clean interface for constructing the ZernikeInputs dataclass instance. + + """ + + @staticmethod + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: + """Build a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. + run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset, positions = None, None + + if run_type in {"training", "simulation", "metrics"}: + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Explicit prior provided; ignoring dataset-based prior." + ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + stamps = get_data_array(data=data, run_type=run_type, key="sources") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + ) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + if len(contributions) == 1: + return contributions[0] + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") + + combined = np.zeros((n_samples, max_order)) + # Pad each contribution to the max order and sum them + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. + + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. + """ + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """Assemble Zernike contributions from prior, centroid correction, and CCD misalignment. + + This function checks the model parameters to determine which contributions to include, computes each contribution as needed, and combines them into a single Zernike contribution tensor. It handles the logic for when certain contributions are not used or not available, ensuring that the final output is correctly shaped and contains the appropriate information based on the configuration. + + Parameters + ---------- + model_params : RecursiveNamespace + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. + centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. + + Returns + ------- + tf.Tensor + A tensor representing the full Zernike contribution map. + """ + zernike_contribution_list = [] + + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, tf.Tensor): + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + + elif not isinstance(zernike_prior, np.ndarray): + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") + + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) + zernike_contribution_list.append(centroid_correction) + else: + logger.info("Skipping centroid correction (not enabled or no dataset).") + + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append(ccd_misalignment) + else: + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) + + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) + + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) + + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + + +def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. + + All inputs should be in [m]. + A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, + e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. + + The output zernike coefficient is in [um] units as expected by wavediff. + + To apply match the centroid with a `dx` that has a corresponding `zk1`, + the new PSF should be generated with `-zk1`. + + The same applies to `dy` and `zk2`. + + Parameters + ---------- + dxy : float + Centroid shift in [m]. It can be on the x-axis or the y-axis. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + reference_pix_sampling = 12e-6 + zernike_norm_factor = 2.0 + + # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) + return ( + zernike_norm_factor + * (tel_diameter / 2) + * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) + * 3.0 + ) + + +def compute_zernike_tip_tilt( + star_images: np.ndarray, + star_masks: Optional[np.ndarray] = None, + pixel_sampling: float = 12e-6, + reference_shifts: list[float] = [-1 / 3, -1 / 3], + sigma_init: float = 2.5, + n_iter: int = 20, +) -> np.ndarray: + """ + Compute Zernike tip-tilt corrections for a batch of PSF images. + + This function estimates the centroid shifts of multiple PSFs and computes + the corresponding Zernike tip-tilt corrections to align them with a reference. + + Parameters + ---------- + star_images : np.ndarray + A batch of PSF images (3D array of shape `(num_images, height, width)`). + star_masks : np.ndarray, optional + A batch of masks (same shape as `star_postage_stamps`). Each mask can have: + - `0` to ignore the pixel. + - `1` to fully consider the pixel. + - Values in `(0,1]` as weights for partial consideration. + Defaults to None. + pixel_sampling : float, optional + The pixel size in meters. Defaults to `12e-6 m` (12 microns). + reference_shifts : list[float], optional + The target centroid shifts in pixels, specified as `[dy, dx]`. + Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). + sigma_init : float, optional + Initial standard deviation for centroid estimation. Default is `2.5`. + n_iter : int, optional + Number of iterations for centroid refinement. Default is `20`. + + Returns + ------- + np.ndarray + An array of shape `(num_images, 2)`, where: + - Column 0 contains `Zk1` (tip) values. + - Column 1 contains `Zk2` (tilt) values. + + Notes + ----- + - This function processes all images at once using vectorized operations. + - The Zernike coefficients are computed in the WaveDiff convention. + """ + from wf_psf.data.centroids import CentroidEstimator + + # Vectorize the centroid computation + centroid_estimator = CentroidEstimator( + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) + + shifts = centroid_estimator.get_intra_pixel_shifts() + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Reshape to ensure it's a column vector (1, 2) + reference_shifts = reference_shifts[None, :] + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + + # Compute displacements + displacements = reference_shifts - shifts # + + # Ensure the correct axis order for displacements (x-axis, then y-axis) + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + + # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + + # Reshape the result back to the original shape of displacements + zk1_2_array = zk1_2_array.reshape(displacements.shape) + + return zk1_2_array + + +def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in zemax conventions. + + All inputs should be in [m]. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + # Convert to waves with a reference of 800nm + zk4 /= 800e-9 + # Remove the peak to valley value + zk4 /= 2.0 + + return zk4 + + +def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in WaveDifff conventions. + + All inputs should be in [m]. + + The output zernike coefficient is in [um] units as expected by wavediff. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + + # Remove the peak to valley value + zk4 /= 2.0 + + # Change units to [um] as Wavediff uses + zk4 *= 1e6 + + return zk4 diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py deleted file mode 100644 index c1402f06..00000000 --- a/src/wf_psf/data/training_preprocessing.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Training Data Processing. - -A module to load and preprocess training and validation test data. - -:Authors: Jennifer Pollack and Tobias Liaudat - -""" - -import os -import numpy as np -import wf_psf.utils.utils as utils -import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -class DataHandler: - """Data Handler. - - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. - - Attributes - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Parameters for the current dataset type. - dataset : dict or None - Dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins. - sed_data : tf.Tensor or None - TensorFlow tensor containing processed SED data for training/testing. - load_data_on_init : bool - Flag controlling whether data is loaded during initialization. - """ - - def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True - ): - """ - Initialize the dataset handler for PSF simulation. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred. Default is True. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() - - def load_dataset(self): - """Load dataset. - - Load the dataset based on the specified dataset type. - - """ - self.dataset = np.load( - os.path.join(self.data_params.data_dir, self.data_params.file), - allow_pickle=True, - )[()] - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - if self.dataset_type == "training": - if "noisy_stars" in self.dataset: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") - elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - - def process_sed_data(self): - """Process SED Data. - - A method to generate and process SED data. - - """ - self.sed_data = [ - utils.generate_SED_elems_in_tensorflow( - _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 - ) - for _sed in self.dataset["SEDs"] - ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) - self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. - """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py new file mode 100644 index 00000000..5df682d0 --- /dev/null +++ b/src/wf_psf/inference/__init__.py @@ -0,0 +1,2 @@ +# src/wf_psf/inference/__init__.py +"""Inference package for PSF generation.""" diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py new file mode 100644 index 00000000..70d7a9be --- /dev/null +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,745 @@ +"""Inference. + +A module which provides a PSFInference class to perform inference +with trained PSF models. It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. + +:Authors: Jennifer Pollack , Tobias Liaudat + +""" + +import os +from pathlib import Path +import numpy as np +from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf +from wf_psf.utils.utils import ensure_batch +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +import tensorflow as tf + + +class InferenceConfigHandler: + """ + Handle configuration loading and management for PSF inference. + + This class manages the loading of inference, training, and data configuration + files required for PSF inference operations. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + inference_config : RecursiveNamespace or None + Loaded inference configuration. + training_config : RecursiveNamespace or None + Loaded training configuration. + data_config : RecursiveNamespace or None + Loaded data configuration. + trained_model_path : Path + Path to the trained model directory. + model_subdir : str + Subdirectory name for model files. + trained_model_config_path : Path + Path to the training configuration file. + data_config_path : str or None + Path to the data configuration file. + """ + + ids = ("inference_conf",) + + def __init__(self, inference_config_path: str): + self.inference_config_path = inference_config_path + self.inference_config = None + self.training_config = None + self.data_config = None + + def load_configs(self): + """ + Load configuration files based on the inference config. + + Loads the inference configuration first, then uses it to determine and load + the training and data configurations. + + Notes + ----- + Updates the following attributes in-place: + - inference_config + - training_config + - data_config (if data_config_path is specified) + """ + + self.inference_config = read_conf(self.inference_config_path) + self.set_config_paths() + self.training_config = read_conf(self.trained_model_config_path) + + if self.data_config_path is not None: + # Load the data configuration + self.data_config = read_conf(self.data_config_path) + + def set_config_paths(self): + """ + Extract and set the configuration paths from the inference config. + + Sets the following attributes: + - trained_model_path + - model_subdir + - trained_model_config_path + - data_config_path + """ + # Set config paths + config_paths = self.inference_config.inference.configs + + self.trained_model_path = Path(config_paths.trained_model_path) + self.model_subdir = config_paths.model_subdir + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) + self.data_config_path = config_paths.data_config_path + + @staticmethod + def overwrite_model_params(training_config=None, inference_config=None): + """ + Overwrite training model_params with values from inference_config if available. + + Parameters + ---------- + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. + """ + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params + + if model_params and inf_model_params: + for key, value in inf_model_params.__dict__.items(): + if hasattr(model_params, key): + setattr(model_params, key, value) + + +class PSFInference: + """ + Perform PSF inference using a pre-trained WaveDiff model. + + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + x_field : array-like, optional + x coordinates in SHE convention. + y_field : array-like, optional + y coordinates in SHE convention. + seds : array-like, optional + Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. + + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + x_field : array-like or None + x coordinates for PSF positions. + y_field : array-like or None + y coordinates for PSF positions. + seds : array-like or None + Spectral energy distributions. + sources : array-like or None + Source postage stamps. + masks : array-like or None + Source masks. + engine : PSFInferenceEngine or None + The inference engine instance. + + Examples + -------- + Basic usage with position coordinates and SEDs: + + .. code-block:: python + + psf_inf = PSFInference( + inference_config_path="config.yaml", + x_field=[100.5, 200.3], + y_field=[150.2, 250.8], + seds=sed_array + ) + psf_inf.run_inference() + psf = psf_inf.get_psf(0) + """ + + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): + + self.inference_config_path = inference_config_path + + # Inputs for the model + self.x_field = x_field + self.y_field = y_field + self.seds = seds + self.sources = sources + self.masks = masks + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + """ + Get or create the configuration handler. + + Returns + ------- + InferenceConfigHandler + The configuration handler instance with loaded configs. + """ + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """ + Prepare the configuration for inference. + + Overwrites training model parameters with inference configuration values. + """ + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config + ) + + @property + def inference_config(self): + """ + Get the inference configuration. + + Returns + ------- + RecursiveNamespace + The inference configuration object. + """ + return self.config_handler.inference_config + + @property + def training_config(self): + """ + Get the training configuration. + + Returns + ------- + RecursiveNamespace + The training configuration object. + """ + return self.config_handler.training_config + + @property + def data_config(self): + """ + Get the data configuration. + + Returns + ------- + RecursiveNamespace or None + The data configuration object, or None if not available. + """ + return self.config_handler.data_config + + @property + def simPSF(self): + """ + Get or create the PSF simulator. + + Returns + ------- + simPSF + The PSF simulator instance. + """ + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.training_config.training.model_params) + return self._simPSF + + def _prepare_dataset_for_inference(self): + """ + Prepare dataset dictionary for inference. + + Returns + ------- + dict or None + Dictionary containing positions, sources, and masks, or None if positions are invalid. + """ + positions = self.get_positions() + if positions is None: + return None + return {"positions": positions, "sources": self.sources, "masks": self.masks} + + @property + def data_handler(self): + """ + Get or create the data handler. + + Returns + ------- + DataHandler + The data handler instance configured for inference. + """ + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=self._prepare_dataset_for_inference(), + sed_data=self.seds, + ) + self._data_handler.run_type = "inference" + return self._data_handler + + @property + def trained_psf_model(self): + """ + Get or load the trained PSF model. + + Returns + ------- + Model + The loaded trained PSF model. + """ + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model + + def get_positions(self): + """ + Combine x_field and y_field into position pairs. + + Returns + ------- + numpy.ndarray + Array of shape (num_positions, 2) where each row contains [x, y] coordinates. + Returns None if either x_field or y_field is None. + + Raises + ------ + ValueError + If x_field and y_field have different lengths. + """ + if self.x_field is None or self.y_field is None: + return None + + x_arr = np.asarray(self.x_field) + y_arr = np.asarray(self.y_field) + + if x_arr.size == 0 or y_arr.size == 0: + return None + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Flatten arrays to handle any input shape, then stack + x_flat = x_arr.flatten() + y_flat = y_arr.flatten() + + return np.column_stack((x_flat, y_flat)) + + def load_inference_model(self): + """Load the trained PSF model based on the inference configuration. + + Returns + ------- + Model + The loaded trained PSF model. + + Notes + ----- + Constructs the weights path pattern based on the trained model path, + model subdirectory, model name, id name, and cycle number specified in the + configuration files. + """ + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name + + weights_path_pattern = os.path.join( + model_path, + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", + ) + + # Load the trained PSF model + return load_trained_psf_model( + self.training_config, + self.data_handler, + weights_path_pattern, + ) + + @property + def n_bins_lambda(self): + """Get the number of wavelength bins for inference. + + Returns + ------- + int + The number of wavelength bins used during inference.""" + if self._n_bins_lambda is None: + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) + return self._n_bins_lambda + + @property + def batch_size(self): + """ + Get the batch size for inference. + + Returns + ------- + int + The batch size for processing during inference. + """ + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + """Get the cycle number for inference. + + Returns + ------- + int + The cycle number used for loading the trained model. + """ + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + """Get the output dimension for PSF inference. + + Returns + ------- + int + The output dimension (height and width) of the inferred PSFs. + """ + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """ + Preprocess and return tensors for positions and SEDs with consistent shapes. + + Handles single-star, multi-star, and even scalar inputs, ensuring: + - positions: shape (n_samples, 2) + - sed_data: shape (n_samples, n_bins, 2) + """ + # Ensure x_field and y_field are at least 1D + x_arr = np.atleast_1d(self.x_field) + y_arr = np.atleast_1d(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) + + if sed_data.shape[2] != 2: + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}" + ) + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + + def run_inference(self): + """Run PSF inference and return the full PSF array. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Prepares configurations and input data, initializes the inference engine, + and computes the PSF for all input positions. + """ + # Prepare the configuration for inference + self.prepare_configs() + + # Prepare positions and SEDs for inference + positions, sed_data = self._prepare_positions_and_seds() + + self.engine = PSFInferenceEngine( + trained_model=self.trained_psf_model, + batch_size=self.batch_size, + output_dim=self.output_dim, + ) + return self.engine.compute_psfs(positions, sed_data) + + def _ensure_psf_inference_completed(self): + """Ensure that PSF inference has been completed. + + Runs inference if it has not been done yet. + """ + if self.engine is None or self.engine.inferred_psfs is None: + self.run_inference() + + def get_psfs(self): + """Get all inferred PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSFs. + """ + self._ensure_psf_inference_completed() + return self.engine.get_psfs() + + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + Parameters + ---------- + index : int, optional + Index of the PSF to retrieve (default is 0). + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSF. + If only a single star was passed during instantiation, the index defaults to 0 + and bounds checking is relaxed. + """ + self._ensure_psf_inference_completed() + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + + def clear_cache(self): + """ + Clear all cached properties and reset the instance. + + This method resets all lazy-loaded properties, including the config handler, + PSF simulator, data handler, trained model, and inference engine. Useful for + freeing memory or forcing a fresh initialization. + + Notes + ----- + After calling this method, accessing any property will trigger re-initialization. + """ + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + self.engine = None + + +class PSFInferenceEngine: + """Engine to perform PSF inference using a trained model. + + This class handles the batch-wise computation of PSFs using a trained PSF model. + It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access. + + Parameters + ---------- + trained_model : Model + The trained PSF model to use for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Attributes + ---------- + trained_model : Model + The trained PSF model used for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Examples + -------- + .. code-block:: python + + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) + """ + + def __init__(self, trained_model, batch_size: int, output_dim: int): + self.trained_model = trained_model + self.batch_size = batch_size + self.output_dim = output_dim + self._inferred_psfs = None + + @property + def inferred_psfs(self) -> np.ndarray: + """Access the cached inferred PSFs, if available. + + Returns + ------- + numpy.ndarray or None + The cached inferred PSFs, or None if not yet computed. + """ + return self._inferred_psfs + + def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: + """Compute and cache PSFs for the input source parameters. + + Parameters + ---------- + positions : tf.Tensor + Tensor of shape (n_samples, 2) containing the (x, y) positions + sed_data : tf.Tensor + Tensor of shape (n_samples, n_bins, 2) containing the SEDs + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + PSFs are computed in batches according to the specified batch_size. + Results are cached internally for subsequent access via get_psfs() or get_psf(). + """ + n_samples = positions.shape[0] + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) + + # Initialize counter + counter = 0 + while counter < n_samples: + # Calculate the batch end element + end_sample = min(counter + self.batch_size, n_samples) + + # Define the batch positions + batch_pos = positions[counter:end_sample, :] + batch_seds = sed_data[counter:end_sample, :, :] + batch_inputs = [batch_pos, batch_seds] + + # Generate PSFs for the current batch + batch_psfs = self.trained_model(batch_inputs, training=False) + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() + + # Update the counter + counter = end_sample + + return self._inferred_psfs + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + """ + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. Call compute_psfs() first.") + return self._inferred_psfs + + def get_psf(self, index: int) -> np.ndarray: + """Get the PSF at a specific index. + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Raises + ------ + ValueError + If PSFs have not yet been computed. + """ + if self._inferred_psfs is None: + raise ValueError("PSFs not yet computed. Call compute_psfs() first.") + return self._inferred_psfs[index] + + def clear_cache(self): + """ + Clear cached inferred PSFs. + + Resets the internal PSF cache to free memory. After calling this method, + compute_psfs() must be called again before accessing PSFs. + """ + self._inferred_psfs = None diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..d72618f3 --- /dev/null +++ b/src/wf_psf/instrument/__init__.py @@ -0,0 +1 @@ +"""Wavefront-based PSF Instrument package.""" diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 89% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py index 1d2135ba..873509e5 100644 --- a/src/wf_psf/utils/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,47 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff + + +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = positions + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array class CCDMisalignmentCalculator: @@ -121,11 +161,7 @@ def _preprocess_tile_data(self) -> None: self.tiles_z_average = np.mean(self.tiles_z_lims) def _initialize_polygons(self): - """Initialize polygons to look for CCD IDs. - - Each CCD is represented by a polygon defined by its corner points. - - """ + """Initialize polygons to look for CCD IDs""" # Build polygon list corresponding to each CCD self.ccd_polygons = [] @@ -346,6 +382,8 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. """ + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff + dz = self.get_dz_from_position(pos) return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 0447d596..942a0622 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -152,8 +152,7 @@ def compute_poly_metric( # Print RMSE values logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, std_rmse) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, std_rel_rmse) - + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {std_rel_rmse:.4e}%") return rmse, rel_rmse, std_rmse, std_rel_rmse @@ -364,9 +363,8 @@ def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_si rel_rmse_std = np.std(rel_rmse_vals) # Print RMSE values - logger.info("Absolute RMSE:\t %.4e % \t +/- %.4e %", rmse, rmse_std) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, rel_rmse_std) - + logger.info("Absolute RMSE:\t %.4e \t +/- %.4e" % (rmse, rmse_std)) + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {rel_rmse_std:.4e}%") return rmse, rel_rmse, rmse_std, rel_rmse_std @@ -596,10 +594,10 @@ def compute_shape_metrics( # Print relative shape/size errors logger.info( - f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e} % \t +/- {std_rel_rmse_e1:.4e} %" + f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e}% \t +/- {std_rel_rmse_e1:.4e}%" ) logger.info( - f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e} % \t +/- {std_rel_rmse_e2:.4e} %" + f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e}% \t +/- {std_rel_rmse_e2:.4e}%" ) # Print number of stars diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3dff2c6c..db410f4e 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -311,7 +311,6 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. @@ -329,8 +328,6 @@ def evaluate_model( DataHandler object containing training and test data psf_model: object PSF model object - weights_path: str - Directory location of model weights metrics_output: str Directory location of metrics output @@ -341,8 +338,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) @@ -351,14 +348,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info(f"Loading PSF model weights from {weights_path}") - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py similarity index 98% rename from src/wf_psf/psf_models/psf_model_parametric.py rename to src/wf_psf/psf_models/models/psf_model_parametric.py index 0cd703d7..4a28417d 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -9,7 +9,7 @@ import tensorflow as tf from wf_psf.psf_models.psf_models import register_psfclass -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -215,7 +215,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py similarity index 67% rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py rename to src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index ad61c1b5..fb8bc902 100644 --- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,11 +10,14 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter +from wf_psf.data.data_handler import get_data_array +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes, +) from wf_psf.psf_models import psf_models as psfm -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler -from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -22,6 +25,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -97,8 +101,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): A Recursive Namespace object containing parameters for this PSF model class. training_params: Recursive Namespace A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler - A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients. + data: DataConfigHandler or dict + A DataConfigHandler object or dict that provides access to single or multiple datasets (e.g. train and test), as well as prior knowledge like Zernike coefficients. coeff_mat: Tensor or None, optional Coefficient matrix defining the parametric PSF field model. @@ -108,204 +112,192 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): Initialized instance of the TFPhysicalPolychromaticField class. """ super().__init__(model_params, training_params, coeff_mat) - self._initialize_parameters_and_layers( - model_params, training_params, data, coeff_mat - ) - - def _initialize_parameters_and_layers( - self, - model_params: RecursiveNamespace, - training_params: RecursiveNamespace, - data: DataConfigHandler, - coeff_mat: Optional[tf.Tensor] = None, - ): - """Initialize Parameters of the PSF model. - - This method sets up the PSF model parameters, observational positions, - Zernike coefficients, and components required for the automatically - differentiable optical forward model. + self.model_params = model_params + self.training_params = training_params + self.data = data + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - Notes - ----- - - Initializes Zernike parameters based on dataset priors. - - Configures the PSF model layers according to `model_params`. - - If `coeff_mat` is provided, the model coefficients are updated accordingly. - """ + # Initialize the model parameters self.output_Q = model_params.output_Q - self.obs_pos = get_obs_positions(data) self.l2_param = model_params.param_hparams.l2_param - # Inputs: Save optimiser history Parametric model features - self.save_optim_history_param = ( - model_params.param_hparams.save_optim_history_param - ) - # Inputs: Save optimiser history NonParameteric model features - self.save_optim_history_nonparam = ( - model_params.nonparam_hparams.save_optim_history_nonparam - ) - self._initialize_zernike_parameters(model_params, data) - self._initialize_layers(model_params, training_params) + self.output_dim = model_params.output_dim - # Initialize the model parameters with non-default value + # Initialise lazy loading of external Zernike prior + self._external_prior = None + + # Set Zernike Polynomial Coefficient Matrix if not None if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) - def _initialize_zernike_parameters(self, model_params, data): - """Initialize the Zernike parameters. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - """ - self.zks_prior = get_zernike_prior(model_params, data, data.batch_size) - self.n_zks_total = max( - model_params.param_hparams.n_zernikes, - tf.cast(tf.shape(self.zks_prior)[1], tf.int32), - ) - self.zernike_maps = psfm.generate_zernike_maps_3d( - self.n_zks_total, model_params.pupil_diameter + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 ) - def _initialize_layers(self, model_params, training_params): - """Initialize the layers of the PSF model. - - This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - """ - self._initialize_physical_layer(model_params) - self._initialize_polynomial_Z_field(model_params) - self._initialize_Zernike_OPD(model_params) - self._initialize_batch_polychromatic_layer(model_params, training_params) - self._initialize_nonparametric_opd_layer(model_params, training_params) - - def _initialize_physical_layer(self, model_params): - """Initialize the physical layer of the PSF model. - - This method initializes the physical layer of the PSF model using parameters - specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - """ - self.tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_prior, - interpolation_type=model_params.interpolation_type, - interpolation_args=model_params.interpolation_args, + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1], ) - def _initialize_polynomial_Z_field(self, model_params): - """Initialize the polynomial Zernike field of the PSF model. - - This method initializes the polynomial Zernike field of the PSF model using - parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter ) - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. + # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + # Eagerly initialise model layers + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + _ = self.tf_poly_Z_field + _ = self.tf_np_poly_opd - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. + def _get_run_type(self, data): + if hasattr(data, "run_type"): + run_type = data.run_type + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "metrics", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + + def _assemble_zernike_contributions(self): + zks_inputs = ZernikeInputsFactory.build( + data=self.data, + run_type=self.run_type, + model_params=self.model_params, + prior=self._external_prior if hasattr(self, "_external_prior") else None, + ) + return assemble_zernike_contributions( + model_params=self.model_params, + zernike_prior=zks_inputs.zernike_prior, + centroid_dataset=zks_inputs.centroid_dataset, + positions=zks_inputs.misalignment_positions, + batch_size=self.training_params.batch_size, + ) - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. + @property + def save_param_history(self) -> bool: + """Check if the model should save the optimization history for parametric features.""" + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. + @property + def save_nonparam_history(self) -> bool: + """Check if the model should save the optimization history for non-parametric features.""" + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) + def get_obs_pos(self): + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" - """ - self.batch_size = training_params.batch_size - self.obscurations = psfm.tf_obscurations( - pupil_diam=model_params.pupil_diameter, - N_filter=model_params.LP_filter_length, - rotation_angle=model_params.obscuration_rotation_angle, + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" ) - self.output_dim = model_params.output_dim - self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. + @property + def zks_total_contribution(self): + return self._zks_total_contribution + + @property + def n_zks_total(self): + """Get the total number of Zernike coefficients.""" + return self._n_zks_total + + @property + def zernike_maps(self): + """Get Zernike maps.""" + return self._zernike_maps + + @property + def opd_dim(self): + return self._opd_dim + + @property + def obscurations(self): + return self._obscurations + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field + + @tf_poly_Z_field.deleter + def tf_poly_Z_field(self): + del self._tf_poly_Z_field + + @property + def tf_physical_layer(self): + """Lazy loading of the physical layer of the PSF model.""" + if not hasattr(self, "_tf_physical_layer"): + self._tf_physical_layer = TFPhysicalLayer( + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) + return self._tf_physical_layer + + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + return TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - def _initialize_nonparametric_opd_layer(self, model_params, training_params): - """Initialize the non-parametric OPD layer. - - This method initializes the non-parametric OPD layer using the provided - `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - - """ - # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() - - self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - d_max=model_params.nonparam_hparams.d_max_nonparam, - opd_dim=tf.shape(self.zernike_maps)[1].numpy(), - ) + @property + def tf_np_poly_opd(self): + """Lazy loading of the non-parametric polynomial variations OPD layer.""" + if not hasattr(self, "_tf_np_poly_opd"): + self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) + return self._tf_np_poly_opd def get_coeff_matrix(self): """Get coefficient matrix.""" @@ -335,18 +327,15 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. @@ -358,6 +347,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -471,12 +461,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -519,10 +513,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -547,10 +544,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -585,9 +585,10 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -622,8 +623,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction @@ -688,22 +689,21 @@ def call(self, inputs, training=True): # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - # Propagate to obtain the OPD + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD - self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) - ) + # Add L2 regularization loss on parametric OPD maps + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Combine both contributions + opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_semiparametric.py rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py index dc535204..7b2ff04d 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -10,9 +10,9 @@ import numpy as np import tensorflow as tf from wf_psf.psf_models import psf_models as psfm -from wf_psf.psf_models import tf_layers as tfl +from wf_psf.psf_models.tf_modules import tf_layers as tfl from wf_psf.utils.utils import decompose_tf_obscured_opd_basis -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, ) @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..e41e3536 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,57 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" + +import logging +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath + +logger = logging.getLogger(__name__) + + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. + + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. + """ + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + status = model.load_weights(weights_path) + status.expect_partial() + + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py index 463d1c52..4c44f698 100644 --- a/src/wf_psf/psf_models/psf_models.py +++ b/src/wf_psf/psf_models/psf_models.py @@ -187,24 +187,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None): def get_psf_model_weights_filepath(weights_filepath): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Parameters ---------- weights_filepath: str - Basename of the psf model weights to be loaded. + Basename of the PSF model weights to be loaded. Returns ------- str - The absolute path concatenated to the basename of the psf model weights to be loaded. + The absolute path concatenated to the basename of the PSF model weights to be loaded. """ try: return glob.glob(weights_filepath)[0].split(".")[0] except IndexError: logger.exception( - "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file." + "PSF weights file not found. Check that you've specified the correct weights file in the your config file." ) raise PSFModelError("PSF model weights error.") diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py similarity index 98% rename from src/wf_psf/psf_models/tf_layers.py rename to src/wf_psf/psf_models/tf_modules/tf_layers.py index eda43305..cdd01e16 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -7,7 +7,8 @@ """ import tensorflow as tf -from wf_psf.psf_models.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf @@ -997,13 +998,10 @@ def call(self, positions): If the shape of the input `positions` tensor is not compatible. """ + # Find indices for all positions in one batch operation + idx = find_position_indices(self.obs_pos, positions) - def calc_index(idx_pos): - return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] - - # Calculate the indices of the input batch - indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) - # Recover the prior zernikes from the batch indexes - batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) + # Gather the corresponding Zernike coefficients + batch_zks = tf.gather(self.zks_prior, idx, axis=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py similarity index 100% rename from src/wf_psf/psf_models/tf_modules.py rename to src/wf_psf/psf_models/tf_modules/tf_modules.py diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py similarity index 97% rename from src/wf_psf/psf_models/tf_psf_field.py rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 0c9ba2f7..07b523d1 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -9,15 +9,16 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFZernikeOPD, TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, TFPhysicalLayer, ) -from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor( + get_data_array(data, data.run_type, key="positions"), dtype=tf.float32 + ) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..4bd1246a --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,99 @@ +"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf + + +@tf.function +def find_position_indices(obs_pos, batch_positions): + """Find indices of batch positions within observed positions using vectorized operations. + + This function locates the indices of multiple query positions within a + reference set of observed positions using broadcasting and vectorized operations. + Each position in the batch must have an exact match in the observed positions. + + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos", + ) + + return indices + + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). + + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) diff --git a/src/wf_psf/psf_models/zernikes.py b/src/wf_psf/psf_models/zernikes.py deleted file mode 100644 index dcfa6e39..00000000 --- a/src/wf_psf/psf_models/zernikes.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Zernikes. - -A module to make Zernike maps. - -:Author: Tobias Liaudat and Jennifer Pollack - -""" - -import numpy as np -import zernike as zk -import logging - -logger = logging.getLogger(__name__) - - -def zernike_generator(n_zernikes, wfe_dim): - """ - Generate Zernike maps. - - Based on the zernike github repository. - https://github.com/jacopoantonello/zernike - - Parameters - ---------- - n_zernikes: int - Number of Zernike modes desired. - wfe_dim: int - Dimension of the Zernike map [wfe_dim x wfe_dim]. - - Returns - ------- - zernikes: list of np.ndarray - List containing the Zernike modes. - The values outside the unit circle are filled with NaNs. - """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", "\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - "print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 62% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 8557704f..185da8d7 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,8 +8,9 @@ import numpy as np import pytest +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator +from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -28,25 +29,6 @@ def calculate_centroid(image, mask=None): return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - return image - - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images - - @pytest.fixture def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" @@ -68,12 +50,6 @@ def simple_star_and_mask(): return image, mask -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" @@ -129,133 +105,84 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02]] - ) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - - # Expected shifts based on centroid calculation - expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Check dy values similarly - np.testing.assert_allclose( - args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 - np.testing.assert_allclose( - zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 - ) # Zk2 - - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] - ) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts, - ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, ( - f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + # Mock the internal function calls: + with ( + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, centroid_dataset) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) + # Mock internal function calls + with ( + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - # Process the displacements and expected values for each image in the batch - expected_dx = ( - reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] - ) # Expected x-axis shift in meters + # Call function under test + result = compute_centroid_correction(model_params, centroid_dataset) - expected_dy = ( - reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] - ) # Expected y-axis shift in meters + # Validate result shape + assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 - ) - np.testing.assert_allclose( - args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 for each image - np.testing.assert_allclose( - zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 - ) # Zk2 for each image + # Validate expected values (adjust based on behavior) + expected_result = np.array( + [ + [0, -0.1, -0.2], # From training data + [0, -0.3, -0.4], + [0, -0.1, -0.2], # From test data (reused mocked return) + [0, -0.3, -0.4], + ] + ) + assert np.allclose(result, expected_result) # Test for centroid calculation without mask @@ -442,9 +369,9 @@ def test_intra_pixel_shifts(simple_image_with_centroid): expected_y_shift = 2.7 - 2.0 # yc - yc0 # Check that the shifts are correct - assert np.isclose(shifts[0], expected_x_shift), ( - f"Expected {expected_x_shift}, got {shifts[0]}" - ) - assert np.isclose(shifts[1], expected_y_shift), ( - f"Expected {expected_y_shift}, got {shifts[1]}" - ) + assert np.isclose( + shifts[0], expected_x_shift + ), f"Expected {expected_x_shift}, got {shifts[0]}" + assert np.isclose( + shifts[1], expected_y_shift + ), f"Expected {expected_y_shift}, got {shifts[1]}" diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 6159d53a..131922e5 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -9,8 +9,13 @@ """ import pytest +import numpy as np +import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -93,6 +98,76 @@ ) +@pytest.fixture +def mock_data(scope="module"): + """Fixture to provide mock data for testing.""" + # Mock positions and Zernike priors + training_positions = tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) + + # Define dummy 5x5 image patches for stars (mock star images) + # Define varied values for 5x5 star images + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) + + return MockData( + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, + ) + + +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + +@pytest.fixture +def simple_image(scope="module"): + """Fixture for a simple star image.""" + num_images = 1 # Change this to test with multiple images + image = np.zeros((num_images, 5, 5)) # Create a 3D array + image[:, 2, 2] = 1 # Place the star at the center for each image + return image + + +@pytest.fixture +def identity_mask(scope="module"): + """Creates a mask where all pixels are fully considered.""" + return np.ones((5, 5)) + + +@pytest.fixture +def multiple_images(scope="module"): + """Fixture for a batch of images with stars at different positions.""" + images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 + images[0, 2, 2] = 1 # Star at center of image 0 + images[1, 1, 3] = 1 # Star at (1, 3) in image 1 + images[2, 3, 1] = 1 # Star at (3, 1) in image 2 + return images + + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py new file mode 100644 index 00000000..d29771a1 --- /dev/null +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -0,0 +1,361 @@ +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_handler import ( + DataHandler, + get_data_array, + extract_star_data, +) +from wf_psf.utils.read_config import RecursiveNamespace + + +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically + data_handler = DataHandler( + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True + ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda + + +def test_load_train_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "train_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "noisy_stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal( + data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] + ) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_load_test_dataset(tmp_path, simPSF): + # Create a temporary directory and a temporary data file + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_dir = data_dir / "test_data.npy" + + # Mock dataset + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "stars": np.array([[5, 6], [7, 8]]), + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + # Save the mock dataset to the temporary data file + np.save(temp_data_dir, mock_dataset) + + # Initialize DataHandler instance + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + dataset_type="test", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, + ) + + # Call the load_dataset method + data_handler.load_dataset() + + # Assertions + assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) + assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) + assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) + + +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "train_data.npy" + + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." + ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_load_test_dataset_missing_stars(tmp_path, simPSF): + """Test that a warning is raised if 'stars' is missing in test data.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + temp_data_file = data_dir / "test_data.npy" + + mock_dataset = { + "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key + "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + } + + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + + n_bins_lambda = 10 + data_handler = DataHandler( + "test", data_params, simPSF, n_bins_lambda, load_data=False + ) + + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." + ): + data_handler.load_dataset() + data_handler.validate_and_process_dataset() + + +def test_extract_star_data_valid_keys(mock_data): + """Test extracting valid data from the dataset.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_masks(mock_data): + """Test extracting star masks from the dataset.""" + result = extract_star_data(mock_data, train_key="masks", test_key="masks") + + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + + np.testing.assert_array_equal(result, expected) + + +def test_extract_star_data_missing_key(mock_data): + """Test that the function raises a KeyError when a key is missing.""" + with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): + extract_star_data(mock_data, train_key="invalid_key", test_key="stars") + + +def test_extract_star_data_partially_missing_key(mock_data): + """Test that the function raises a KeyError if only one key is missing.""" + with pytest.raises( + KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" + ): + extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") + + +def test_extract_star_data_tensor_conversion(mock_data): + """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" + result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") + + assert isinstance(result, np.ndarray), "The result should be a NumPy array" + assert result.dtype == np.float32, "The NumPy array should have dtype float32" + + +def test_reference_shifts_broadcasting(): + reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts + shifts = np.random.rand(2, 2400) # Example shifts array + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to( + reference_shifts[:, None], shifts.shape + ) # Shape will be (2, 2400) + + # Ensure shapes are compatible for subtraction + displacements = reference_shifts - shifts + + # Test the result + assert displacements.shape == shifts.shape, "Shapes do not match" + assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # ================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..66d23309 --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,543 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +import tensorflow as tf +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, + compute_zernike_tip_tilt, + pad_tf_zernikes, +) +from types import SimpleNamespace as RecursiveNamespace + + +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + + +def test_training_without_prior(mock_model_params, mock_data): + mock_model_params.use_prior = False + + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + ] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + + assert zinputs.zernike_prior is None + + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): + mock_model_params.use_prior = True + + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) + + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) + + assert "Explicit prior provided; ignoring dataset-based prior." in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + mock_model_params.use_prior = True + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) + + +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) + + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), + ) + + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + + result = get_np_zernike_prior(data) + + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions, + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None, + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, + param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), dtype=tf.float32) + + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, zernike_prior=tensor_prior + ) + assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" + assert isinstance(result, tf.Tensor) + np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +def test_inconsistent_shapes_raises_error( + mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): + mock_model_params.add_ccd_misalignments = False + mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 + + with pytest.raises( + ValueError, match="All contributions must have the same number of samples" + ): + assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=None, + ) + + +def test_pad_zernikes_num_of_zernikes_equal(): + # Prepare your test tensors + zk_param = tf.constant([[[[1.0]]], [[[2.0]]]]) # Shape (2, 1, 1, 1) + zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]]) # Same shape + + # Reshape to (1, 2, 1, 1) + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) + + # Reset _n_zks_total to max number of zernikes (2 here) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call pad_zernikes method + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert shapes are equal and correct + assert padded_zk_param.shape[1] == n_zks_total + assert padded_zk_prior.shape[1] == n_zks_total + + # If num zernikes already equal, output should be unchanged + np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) + np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + + +def test_pad_zernikes_prior_greater_than_param(): + zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 5, 1, 1) + assert padded_zk_prior.shape == (1, 5, 1, 1) + + +def test_pad_zernikes_param_greater_than_prior(): + zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) + zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 4, 1, 1) + assert padded_zk_prior.shape == (1, 4, 1, 1) + + +def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): + """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02]] + ) # Shape (1, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test + ) + + # Define test inputs (batch of 1 image) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + + # Expected shifts based on centroid calculation + expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters + expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters + + # Expected calls to the mocked function + # Extract the arguments passed to mock_shift_fn + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose( + args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Check dy values similarly + np.testing.assert_allclose( + args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose( + zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 + np.testing.assert_allclose( + zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 + + +def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): + """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] + ) # Shape (3, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test + ) + + # Define test inputs (batch of 3 images) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts, + ) + + # Check if the mock function was called once with the full batch + assert ( + len(mock_shift_fn.call_args_list) == 1 + ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + + # Get the arguments passed to the mock function for the batch of images + args, _ = mock_shift_fn.call_args_list[0] + + print("Shape of args[0]:", args[0].shape) + print("Contents of args[0]:", args[0]) + print("Mock function call args list:", mock_shift_fn.call_args_list) + + # Reshape args[0] to (N, 2) for batch processing + args_array = np.array(args[0]).reshape(-1, 2) + + # Process the displacements and expected values for each image in the batch + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] + ) # Expected y-axis shift in meters + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose( + args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + np.testing.assert_allclose( + args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose( + zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 for each image + np.testing.assert_allclose( + zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 for each image diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..de111427 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,36 @@ +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors=None, + test_zernike_priors=None, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks, + ) + self.test_data = MockDataset( + positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks, + ) diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/training_preprocessing_test.py deleted file mode 100644 index 3efc8272..00000000 --- a/src/wf_psf/tests/test_data/training_preprocessing_test.py +++ /dev/null @@ -1,407 +0,0 @@ -import pytest -import numpy as np -import tensorflow as tf -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.data.training_preprocessing import ( - DataHandler, - get_obs_positions, - get_zernike_prior, - extract_star_data, - compute_centroid_correction, -) -from unittest.mock import patch - - -class MockData: - def __init__( - self, - training_positions, - test_positions, - training_zernike_priors, - test_zernike_priors, - noisy_stars=None, - noisy_masks=None, - stars=None, - masks=None, - ): - self.training_data = MockDataset( - positions=training_positions, - zernike_priors=training_zernike_priors, - star_type="noisy_stars", - stars=noisy_stars, - masks=noisy_masks, - ) - self.test_data = MockDataset( - positions=test_positions, - zernike_priors=test_zernike_priors, - star_type="stars", - stars=stars, - masks=masks, - ) - - -class MockDataset: - def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = { - "positions": positions, - "zernike_prior": zernike_priors, - star_type: stars, - "masks": masks, - } - - -@pytest.fixture -def mock_data(): - # Mock data for testing - # Mock training and test positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) - # Mock noisy stars, stars and masks - noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) - noisy_masks = tf.constant([[1], [0]], dtype=tf.float32) - stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32) - masks = tf.constant([[0], [1]], dtype=tf.float32) - - return MockData( - training_positions, - test_positions, - training_zernike_priors, - test_zernike_priors, - noisy_stars, - noisy_masks, - stars, - masks, - ) - - -def test_load_train_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "train_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda, load_data=False - ) - - # Call the load_dataset method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal( - data_handler.dataset["noisy_stars"], mock_dataset["noisy_stars"] - ) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_test_dataset(tmp_path, data_params, simPSF): - # Create a temporary directory and a temporary data file - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_dir = data_dir / "test_data.npy" - - # Mock dataset - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - # Save the mock dataset to the temporary data file - np.save(temp_data_dir, mock_dataset) - - # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False - ) - - # Call the load_dataset method - data_handler.load_dataset() - - # Assertions - assert np.array_equal(data_handler.dataset["positions"], mock_dataset["positions"]) - assert np.array_equal(data_handler.dataset["stars"], mock_dataset["stars"]) - assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) - - -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "train_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda, load_data=False - ) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") - - -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'stars' is missing in test data.""" - data_dir = tmp_path / "data" - data_dir.mkdir() - temp_data_file = data_dir / "test_data.npy" - - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), # No 'stars' key - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - - n_bins_lambda = 10 - data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False - ) - - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: - data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - assert isinstance(data_handler.sed_data, tf.Tensor) - assert data_handler.sed_data.dtype == tf.float32 - assert data_handler.sed_data.shape == ( - len(data_handler.dataset["positions"]), - n_bins_lambda, - len(["feasible_N", "feasible_wv", "SED_norm"]), - ) - - -def test_get_obs_positions(mock_data): - observed_positions = get_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - - -def test_extract_star_data_valid_keys(mock_data): - """Test extracting valid data from the dataset.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) - np.testing.assert_array_equal(result, expected) - - -def test_extract_star_data_masks(mock_data): - """Test extracting star masks from the dataset.""" - result = extract_star_data(mock_data, train_key="masks", test_key="masks") - - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) - np.testing.assert_array_equal(result, expected) - - -def test_extract_star_data_missing_key(mock_data): - """Test that the function raises a KeyError when a key is missing.""" - with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): - extract_star_data(mock_data, train_key="invalid_key", test_key="stars") - - -def test_extract_star_data_partially_missing_key(mock_data): - """Test that the function raises a KeyError if only one key is missing.""" - with pytest.raises( - KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" - ): - extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") - - -def test_extract_star_data_tensor_conversion(mock_data): - """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" - result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - assert isinstance(result, np.ndarray), "The result should be a NumPy array" - assert result.dtype == np.float32, "The NumPy array should have dtype float32" - - -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock the internal function calls: - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose( - result[0, :], np.array([0, -0.1, -0.2]) - ) # First star Zernike coefficients - assert np.allclose( - result[1, :], np.array([0, -0.3, -0.4]) - ) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock internal function calls - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array( - [ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4], - ] - ) - assert np.allclose(result, expected_result) - - -def test_reference_shifts_broadcasting(): - reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts - shifts = np.random.rand(2, 2400) # Example shifts array - - # Ensure reference_shifts is a NumPy array (if it's not already) - reference_shifts = np.array(reference_shifts) - - # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to( - reference_shifts[:, None], shifts.shape - ) # Shape will be (2, 2400) - - # Ensure shapes are compatible for subtraction - displacements = reference_shifts - shifts - - # Test the result - assert displacements.shape == shifts.shape, "Shapes do not match" - assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py new file mode 100644 index 00000000..4cff7a13 --- /dev/null +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 +1,570 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. + +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine, +) + +from wf_psf.utils.read_config import RecursiveNamespace + + +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + ), + ), + ) + ) + return training_config + + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, + ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins, 2), + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) + + inference = PSFInference( + "dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim, + } + + +@pytest.fixture +def mock_compute_psfs_with_cache(psf_test_setup): + """ + Fixture that patches PSFInferenceEngine.compute_psfs with a side effect + that populates the engine's cache. + + Returns + ------- + dict + Dictionary containing: + - 'mock': The mock object for compute_psfs + - 'inference': The PSFInference instance + - 'positions': Mock positions tensor + - 'seds': Mock SEDs tensor + - 'expected_psfs': Expected PSF array + """ + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + with patch.object(PSFInferenceEngine, "compute_psfs") as mock_compute_psfs: + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + yield { + "mock": mock_compute_psfs, + "inference": inference, + "positions": mock_positions, + "seds": mock_seds, + "expected_psfs": expected_psfs, + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + """Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) + + # Assert that the model_params were overwritten correctly + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference("/dummy/path.yaml") + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called["load"] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) + + inference.prepare_configs() + + assert "load" in called # Confirm lazy load happened + + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) + ) + assert inference.batch_size == 4 + + +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) + + # Assert calls to the mocked methods + mock_load_trained_psf_model.assert_called_once_with( + mock_training_config, mock_data_config, weights_path_pattern + ) + + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + + +@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q + + +@patch.object(PSFInference, "_prepare_positions_and_seds") +def test_get_psfs_runs_inference( + mock_prepare_positions_and_seds, mock_compute_psfs_with_cache +): + """Test that get_psfs uses cached PSFs after first computation.""" + mock = mock_compute_psfs_with_cache["mock"] + inference = mock_compute_psfs_with_cache["inference"] + mock_positions = mock_compute_psfs_with_cache["positions"] + mock_seds = mock_compute_psfs_with_cache["seds"] + expected_psfs = mock_compute_psfs_with_cache["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock.call_count == 1 + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_inference_clear_cache(psf_test_setup): + """Test that PSFInference clear_cache resets the instance of PSFInference.""" + inference = psf_test_setup["inference"] + inference._simPSF = MagicMock() + inference._data_handler = MagicMock() + inference._trained_psf_model = MagicMock() + inference._n_bins_lambda = MagicMock() + inference._batch_size = MagicMock() + inference._cycle = MagicMock() + inference._output_dim = MagicMock() + inference.engine = MagicMock() + + # Clear the cache + inference.clear_cache() + + # Check that the internal cache is None + assert ( + inference._config_handler == None, + inference._simPSF == None, + inference._data_handler == None, + inference._trained_psf_model == None, + inference._n_bins_lambda == None, + inference._batch_size == None, + inference._cycle == None, + inference._output_dim == None, + inference.engine == None, + ), "Inference attributes should be cleared to None" # type: ignore + + +def test_engine_clear_cache(psf_test_setup): + """Test that clear_cache resets the internal PSF cache.""" + inference = psf_test_setup["inference"] + expected_psfs = psf_test_setup["expected_psfs"] + + # Create the engine and compute PSFs + inference.engine = PSFInferenceEngine( + trained_model=inference.trained_psf_model, + batch_size=inference.batch_size, + output_dim=inference.output_dim, + ) + + inference.engine._inferred_psfs = expected_psfs + + # Clear the cache + inference.engine.clear_cache() + + # Check that the internal cache is None + assert ( + inference.engine._inferred_psfs is None + ), "PSF cache should be cleared to None" diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 44bd09bf..35f3b9a1 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import pytest from wf_psf.metrics.metrics_interface import evaluate_model -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler @pytest.fixture @@ -106,7 +106,6 @@ def test_evaluate_model_flags( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -134,7 +133,6 @@ def test_missing_ground_truth_model_raises( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) @@ -168,7 +166,6 @@ def test_plotting_config_passed( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -200,8 +197,6 @@ def test_evaluate_model( ) as mock_evaluate_shape_results_dict, patch("numpy.save", new_callable=MagicMock) as mock_np_save, ): - # Mock the logger - _ = mocker.patch("wf_psf.metrics.metrics_interface.logger") # Call evaluate_model evaluate_model( @@ -209,7 +204,6 @@ def test_evaluate_model( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index cae7b141..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,8 @@ import pytest import numpy as np import tensorflow as tf -from wf_psf.psf_models.psf_model_physical_polychromatic import ( +from unittest.mock import patch +from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) from wf_psf.utils.configs_handler import DataConfigHandler @@ -28,14 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return mock_instance @@ -47,256 +61,57 @@ def mock_model_params(mocker): return model_params_mock -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data - ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor - ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) + return instance - # Create a mock for the TFPhysicalLayer class - mocker.patch("wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer") - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior +def test_compute_zernikes(mocker, physical_layer_instance): + # Expected output of mock components + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], dtype=tf.float32 ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior + # Patch tf_poly_Z_field method + mocker.patch.object( + TFPhysicalPolychromaticField, + "tf_poly_Z_field", + return_value=padded_zernike_param, ) - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) - assert padded_zk_prior.shape == (1, 4, 1, 1) - - -def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity - - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) - + # Patch tf_physical_layer.call method + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) - mocker.patch.object( - physical_layer_instance, - "pad_zernikes", + + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior), ) - # Call the method under test + # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) assert zernike_coeffs.shape == expected_values.shape - - # Assert that the tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index 066e1328..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -7,8 +7,8 @@ """ -from wf_psf.psf_models import ( - psf_models, +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.models import ( psf_model_semiparametric, psf_model_physical_polychromatic, ) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index d95761e9..57dfdc8a 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,17 @@ :Author: Jennifer Pollack - """ import pytest +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + DataConfigHandler, +) import os @@ -116,10 +118,21 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance ) - # Patch the load_dataset and process_sed_data methods inside DataHandler - mocker.patch.object(DataHandler, "load_dataset") + # Patch process_sed_data method mocker.patch.object(DataHandler, "process_sed_data") + # Patch validate_and_process_datasetmethod + mocker.patch.object(DataHandler, "validate_and_process_dataset") + + # Patch load_dataset to assign dataset + def mock_load_dataset(self): + self.dataset = { + "SEDs": ["dummy_sed_data"], + "positions": ["dummy_positions_data"], + } + + mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset) + # Create DataConfigHandler instance data_config_handler = DataConfigHandler( "/path/to/data_config.yaml", @@ -141,7 +154,7 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke assert ( data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size - ) # Default value + ) def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): @@ -229,21 +242,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler(path_to_tmp_output_dir, path_to_config_dir) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index dacf0bdc..cc7f2a2b 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -20,11 +20,6 @@ from unittest import mock - -def test_sanity(): - assert 1 + 1 == 2 - - def test_downsample_basic(): """Test apply_mask when a zeroed mask is provided.""" img_dim = (10, 10) @@ -37,9 +32,9 @@ def test_downsample_basic(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_initialization(): @@ -121,9 +116,9 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), ( - "apply_mask should return the window when mask is None." - ) + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." def test_apply_mask_with_valid_mask(): @@ -139,9 +134,9 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), ( - "apply_mask did not apply the mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." def test_apply_mask_with_zeroed_mask(): @@ -156,9 +151,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_unobscured_zernike_projection(): @@ -252,6 +247,7 @@ def test_tf_decompose_obscured_opd_basis(): assert rmse_error < tol + def test_downsample_basic(): """Downsample a small array to a smaller square size.""" arr = np.arange(16).reshape(4, 4).astype(np.float32) @@ -262,9 +258,10 @@ def test_downsample_basic(): assert result.shape == (output_dim, output_dim), "Output shape mismatch" # Values should be averaged/downsampled; simple check - assert np.all(result >= arr.min()) and np.all(result <= arr.max()), \ - "Values outside input range" - + assert np.all(result >= arr.min()) and np.all( + result <= arr.max() + ), "Values outside input range" + def test_downsample_identity(): """Downsample to the same size should return same array (approximately).""" @@ -274,10 +271,12 @@ def test_downsample_identity(): # Since OpenCV / skimage may do minor interpolation, allow small tolerance np.testing.assert_allclose(result, arr, rtol=1e-6, atol=1e-6) + # ---------------------------- # Backend fallback tests # ---------------------------- + @mock.patch("wf_psf.utils.utils._HAS_CV2", False) @mock.patch("wf_psf.utils.utils._HAS_SKIMAGE", False) def test_downsample_no_backend(): @@ -296,10 +295,11 @@ def test_downsample_values_average(): # All output values should be close to input value np.testing.assert_allclose(result, 3.0, rtol=1e-6, atol=1e-6) + @mock.patch("wf_psf.utils.utils._HAS_CV2", True) def test_downsample_non_square_array(): """Check downsampling works for non-square arrays.""" arr = np.arange(12).reshape(3, 4).astype(np.float32) output_dim = 2 result = downsample_im(arr, output_dim) - assert result.shape == (2, 2) \ No newline at end of file + assert result.shape == (2, 2) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index bb0e3df9..ab2f0ac1 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -7,6 +7,7 @@ """ +import gc import numpy as np import time import tensorflow as tf @@ -273,10 +274,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): - """Generate factory for loss, metrics, monitor, and outputs. - - A function to generate loss, metrics, monitor, and outputs - for training. + """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. Parameters ---------- @@ -369,12 +367,12 @@ def train( psf_model_dir : str Directory where the final trained PSF model weights will be saved per cycle. - Notes - ----- - - Utilizes TensorFlow and TensorFlow Addons for model training and optimization. - - Supports masked mean squared error loss for training with masked data. - - Allows for projection of data-driven features onto parametric models between cycles. - - Supports resetting of non-parametric features to initial states. + Returns + ------- + None + + Side Effects + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations @@ -538,3 +536,8 @@ def train( final_time = time.time() logger.info("\nTotal elapsed time: %f" % (final_time - starting_time)) logger.info("\n Training complete..") + + # Clean up memory + del psf_model + gc.collect() + tf.keras.backend.clear_session() diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 5d45ba79..13ceb6de 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf -from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models +from wf_psf.data.data_handler import DataHandler from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -127,28 +129,31 @@ class DataConfigHandler: def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True): try: self.data_conf = read_conf(data_conf) - except FileNotFoundError as e: - logger.exception(e) - exit() - except TypeError as e: + except (FileNotFoundError, TypeError) as e: logger.exception(e) exit() self.simPSF = psf_models.simPSF(training_model_params) + + # Extract sub-configs early + train_params = self.data_conf.data.training + test_params = self.data_conf.data.test + self.training_data = DataHandler( dataset_type="training", - data_params=self.data_conf.data, + data_params=train_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) self.test_data = DataHandler( dataset_type="test", - data_params=self.data_conf.data, + data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size @@ -183,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) @@ -254,8 +260,13 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) self._file_handler = file_handler - self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.training_conf = training_conf + self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) + self.trained_psf_model = self._load_trained_psf_model() @property def metrics_conf(self): @@ -270,32 +281,29 @@ def metrics_conf(self): """ return self._metrics_conf - @property - def metrics_dir(self): - """Get Metrics Directory. - - A function that returns path - of metrics directory. - - Returns - ------- - str - Absolute path to metrics directory - """ - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): - """Get Training Conf. - - A function to return the training configuration file name. + """Returns the loaded training configuration.""" + return self._training_conf - Returns - ------- - RecursiveNamespace - An instance of the training configuration file. + @training_conf.setter + def training_conf(self, training_conf): """ - return self._training_conf + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf @property def plotting_conf(self): @@ -310,112 +318,106 @@ def plotting_conf(self): """ return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - """Get Data Conf. - - A function to return an instance of the DataConfigHandler class. + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, + self.data_conf, + weights_path_pattern, + ) - Returns - ------- - An instance of the DataConfigHandler class. + def _get_training_conf_path_from_metrics(self): """ - return self._load_data_conf() - - @property - def psf_model(self): - """Get PSF Model. - - A function to return an instance of the PSF model - to be evaluated. + Retrieves the full path to the training config based on the metrics configuration. Returns ------- - psf_model: obj - An instance of the PSF model to be evaluated. + str + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. + FileNotFoundError + If the file does not exist at the constructed path. """ - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, - self.data_conf, + trained_model_path = self._get_trained_model_path() + + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e + + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, ) - @property - def weights_path(self): - """Get Weights Path. + if not os.path.exists(training_conf_path): + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" + ) - A function to return the full path - of the user-specified psf model weights to be loaded. + return training_conf_path - Returns - ------- - str - A string representing the full path to the psf model weights to be loaded. + def _get_trained_model_path(self): """ - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. + Determine the trained model path from either: - Helper method to get the trained model path. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- str - A string representing the path to the trained model output run directory. + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) - except TypeError as e: - logger.exception(e) + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" ) + return trained_model_path - def _load_training_conf(self, training_conf): - """Load Training Conf. - - Load the training configuration if training_conf is not provided. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. - - Returns - ------- - RecursiveNamespace storing the training configuration parameter settings. - - """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." - ) - else: - return training_conf + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) + return fallback_path def _load_data_conf(self): """Load Data Conf. @@ -439,27 +441,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified psf model weights path. - - Returns - ------- - weights_basename: str - The basename of the psf model weights to be loaded. - - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -502,18 +483,17 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. """ - logger.info(f"Running metrics evaluation on psf model: {self.weights_path}") + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, - self.weights_path, + self.trained_psf_model, self.metrics_dir, ) diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/utils/preprocessing.py deleted file mode 100644 index 210c03e5..00000000 --- a/src/wf_psf/utils/preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Preprocessing. - -A module with utils to preprocess data. - -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index 875ae8ed..48d23e00 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -140,4 +140,4 @@ def read_stream(conf_file): docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: # noqa: UP028 - yield doc + yield doc diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 1b1f2d6d..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -26,6 +26,47 @@ pass +def scale_to_range(input_array, old_range, new_range): + # Scale to [0,1] + input_array = (input_array - old_range[0]) / (old_range[1] - old_range[0]) + # Scale to new_range + input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] + return input_array + + +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. + + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. + """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + + +def calc_wfe(zernike_basis, zks): + wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) + return wfe + + +def calc_wfe_rms(zernike_basis, zks, pupil_mask): + wfe = calc_wfe(zernike_basis, zks) + wfe_rms = np.sqrt(np.mean((wfe[pupil_mask] - np.mean(wfe[pupil_mask])) ** 2)) + return wfe_rms + + def generalised_sigmoid(x, max_val=1, power_k=1): """ Apply a generalized sigmoid function to the input.