From d7cbe1aea7d32bfb181dcdd65e2cff124c58953d Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 13 May 2025 11:33:38 +0200
Subject: [PATCH 001/146] Correct syntax in docstring and generalise exception
 message

---
 src/wf_psf/psf_models/psf_models.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py
index dcb6abd9..41976590 100644
--- a/src/wf_psf/psf_models/psf_models.py
+++ b/src/wf_psf/psf_models/psf_models.py
@@ -186,24 +186,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None):
 def get_psf_model_weights_filepath(weights_filepath):
     """Get PSF model weights filepath.
 
-    A function to return the basename of the user-specified psf model weights path.
+    A function to return the basename of the user-specified PSF model weights path.
 
     Parameters
     ----------
     weights_filepath: str
-        Basename of the psf model weights to be loaded.
+        Basename of the PSF model weights to be loaded.
 
     Returns
     -------
     str
-        The absolute path concatenated to the basename of the psf model weights to be loaded.
+        The absolute path concatenated to the basename of the PSF model weights to be loaded.
 
     """
     try:
         return glob.glob(weights_filepath)[0].split(".")[0]
     except IndexError:
         logger.exception(
-            "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
+            "PSF weights file not found. Check that you've specified the correct weights file in your config file."
         )
         raise PSFModelError("PSF model weights error.")

From 18ee53327179c43f6b8695a65bf55ee015cf8d91 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 13 May 2025 11:35:24 +0200
Subject: [PATCH 002/146] Add inference and test_inference packages

---
 src/wf_psf/inference/__init__.py                       | 0
 src/wf_psf/inference/psf_inference.py                  | 0
 src/wf_psf/tests/test_inference/test_psf_inference.py  | 0
 3 files changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 src/wf_psf/inference/__init__.py
 create mode 100644 src/wf_psf/inference/psf_inference.py
 create mode 100644 src/wf_psf/tests/test_inference/test_psf_inference.py

diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/test_psf_inference.py
new file mode 100644
index 00000000..e69de29b

From 3358e7bf2ba22a6d82070e44b63da8af6536fced Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 13 May 2025 14:19:39 +0200
Subject: [PATCH 003/146] Refactor: Encapsulate logic in psf_models package
 with subpackages: models and tf_modules, add/rm modules, update import
 statements and tests

---
 src/wf_psf/__init__.py                          |  6 +++---
 src/wf_psf/psf_models/models/__init__.py        |  0
 .../{ => models}/psf_model_parametric.py        |  2 +-
 .../psf_model_physical_polychromatic.py         |  8 ++++----
 .../{ => models}/psf_model_semiparametric.py    |  4 ++--
 src/wf_psf/psf_models/tf_modules/__init__.py    |  0
 .../psf_models/{ => tf_modules}/tf_layers.py    |  2 +-
 .../psf_models/{ => tf_modules}/tf_modules.py   |  0
 .../psf_models/{ => tf_modules}/tf_psf_field.py |  4 ++--
 .../psf_model_physical_polychromatic_test.py    | 14 +++++++-------
 .../tests/test_psf_models/psf_models_test.py    |  6 +++---
 11 files changed, 23 insertions(+), 23 deletions(-)
 create mode 100644 src/wf_psf/psf_models/models/__init__.py
 rename src/wf_psf/psf_models/{ => models}/psf_model_parametric.py (99%)
 rename src/wf_psf/psf_models/{ => models}/psf_model_physical_polychromatic.py (99%)
 rename src/wf_psf/psf_models/{ => models}/psf_model_semiparametric.py (99%)
 create mode 100644 src/wf_psf/psf_models/tf_modules/__init__.py
 rename src/wf_psf/psf_models/{ => tf_modules}/tf_layers.py (99%)
 rename src/wf_psf/psf_models/{ => tf_modules}/tf_modules.py (100%)
 rename src/wf_psf/psf_models/{ => tf_modules}/tf_psf_field.py (99%)

diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py
index 5df41b29..863675f1 100644
--- a/src/wf_psf/__init__.py
+++ b/src/wf_psf/__init__.py
@@ -2,6 +2,6 @@
 
 # Dynamically import modules to trigger side effects when wf_psf is imported
 importlib.import_module('wf_psf.psf_models.psf_models')
-importlib.import_module('wf_psf.psf_models.psf_model_semiparametric')
-importlib.import_module('wf_psf.psf_models.psf_model_physical_polychromatic')
-importlib.import_module('wf_psf.psf_models.tf_psf_field')
+importlib.import_module('wf_psf.psf_models.models.psf_model_semiparametric')
+importlib.import_module('wf_psf.psf_models.models.psf_model_physical_polychromatic')
+importlib.import_module('wf_psf.psf_models.tf_modules.tf_psf_field')
diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py
similarity index 99%
rename from src/wf_psf/psf_models/psf_model_parametric.py
rename to src/wf_psf/psf_models/models/psf_model_parametric.py
index 94f79f40..1b095852 100644
--- a/src/wf_psf/psf_models/psf_model_parametric.py
+++ b/src/wf_psf/psf_models/models/psf_model_parametric.py
@@ -9,7 +9,7 @@
 import tensorflow as tf
 from wf_psf.psf_models.psf_models import register_psfclass
-from wf_psf.psf_models.tf_layers import (
+from wf_psf.psf_models.tf_modules.tf_layers import (
     TFPolynomialZernikeField,
     TFZernikeOPD,
     TFBatchPolychromaticPSF,
diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
similarity index 99%
rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py
rename to src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index ad61c1b5..f9ed8765 100644
--- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -10,11 +10,9 @@
 from typing import Optional
 import tensorflow as tf
 from tensorflow.python.keras.engine import data_adapter
-from wf_psf.psf_models import psf_models as psfm
-from wf_psf.utils.read_config import RecursiveNamespace
-from wf_psf.utils.configs_handler import DataConfigHandler
 from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior
-from wf_psf.psf_models.tf_layers import (
+from wf_psf.psf_models import psf_models as psfm
+from wf_psf.psf_models.tf_modules.tf_layers import (
     TFPolynomialZernikeField,
     TFZernikeOPD,
     TFBatchPolychromaticPSF,
@@ -22,6 +20,8 @@
     TFNonParametricPolynomialVariationsOPD,
     TFPhysicalLayer,
 )
+from wf_psf.utils.read_config import RecursiveNamespace
+from wf_psf.utils.configs_handler import DataConfigHandler
 import logging
diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py
similarity index 99%
rename from src/wf_psf/psf_models/psf_model_semiparametric.py
rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py
index dc535204..c370956c 100644
--- a/src/wf_psf/psf_models/psf_model_semiparametric.py
+++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py
@@ -10,9 +10,9 @@
 import numpy as np
 import tensorflow as tf
 from wf_psf.psf_models import psf_models as psfm
-from wf_psf.psf_models import tf_layers as tfl
+from wf_psf.psf_models.tf_modules import tf_layers as tfl
 from wf_psf.utils.utils import decompose_tf_obscured_opd_basis
-from wf_psf.psf_models.tf_layers import (
+from wf_psf.psf_models.tf_modules.tf_layers import (
     TFBatchPolychromaticPSF,
     TFBatchMonochromaticPSF,
 )
diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py
similarity index 99%
rename from src/wf_psf/psf_models/tf_layers.py
rename to src/wf_psf/psf_models/tf_modules/tf_layers.py
index 4479e73f..9d6f77c9 100644
--- a/src/wf_psf/psf_models/tf_layers.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py
@@ -1,6 +1,6 @@
 import tensorflow as tf
 import tensorflow_addons as tfa
-from wf_psf.psf_models.tf_modules import TFMonochromaticPSF
+from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF
 from wf_psf.utils.utils import calc_poly_position_mat
 import wf_psf.utils.utils as utils
 import logging
diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py
similarity index 100%
rename from src/wf_psf/psf_models/tf_modules.py
rename to src/wf_psf/psf_models/tf_modules/tf_modules.py
diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py
similarity index 99%
rename from src/wf_psf/psf_models/tf_psf_field.py
rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py
index 9ef58ac6..45ea91dd 100644
--- a/src/wf_psf/psf_models/tf_psf_field.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py
@@ -9,13 +9,13 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.engine import data_adapter
-from wf_psf.psf_models.tf_layers import (
+from wf_psf.psf_models.tf_modules.tf_layers import (
     TFZernikeOPD,
     TFBatchPolychromaticPSF,
     TFBatchMonochromaticPSF,
     TFPhysicalLayer,
 )
-from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField
+from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField
 from wf_psf.data.training_preprocessing import get_obs_positions
 from wf_psf.psf_models import psf_models as psfm
 import logging
diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py
index 81042c9a..e25d608d 100644
--- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py
+++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py
@@ -9,7 +9,7 @@
 import pytest
 import numpy as np
 import tensorflow as tf
-from wf_psf.psf_models.psf_model_physical_polychromatic import (
+from wf_psf.psf_models.models.psf_model_physical_polychromatic import (
     TFPhysicalPolychromaticField,
 )
 from wf_psf.utils.configs_handler import DataConfigHandler
@@ -54,7 +54,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior):
 
     # Mock internal methods called during initialization
     mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior",
         return_value=zks_prior,
     )
 
@@ -92,7 +92,7 @@ def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks
 
     # Mock internal methods called during initialization
     mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior",
         return_value=zks_prior,
     )
 
@@ -146,13 +146,13 @@ def test_initialize_physical_layer_mocking(
 
     # Mock internal methods called during initialization
     mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior",
         return_value=zks_prior,
     )
 
     # Create a mock for the TFPhysicalLayer class
     mock_physical_layer_class = mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer"
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer"
     )
 
     # Create TFPhysicalPolychromaticField instance
@@ -176,13 +176,13 @@ def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior):
 
     # Mock internal methods called during initialization
     mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior",
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior",
         return_value=zks_prior,
     )
 
     # Create a mock for the TFPhysicalLayer class
     mocker.patch(
-        "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer"
+        "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer"
     )
 
     # Create TFPhysicalPolychromaticField instance
diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py
index 066e1328..b7c906f6 100644
--- a/src/wf_psf/tests/test_psf_models/psf_models_test.py
+++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py
@@ -7,10 +7,10 @@
 """
 
-from wf_psf.psf_models import (
-    psf_models,
+from wf_psf.psf_models import psf_models
+from wf_psf.psf_models.models import (
     psf_model_semiparametric,
-    psf_model_physical_polychromatic,
+    psf_model_physical_polychromatic
 )
 import tensorflow as tf
 import numpy as np
- """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes From 22edf80b69a5f28facb9f96a3c368d832c790924 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:31:47 +0200 Subject: [PATCH 005/146] Correct syntax in docstrings and logger messages --- src/wf_psf/utils/configs_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index a7f76187..4f78990e 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -378,12 +378,12 @@ def _load_data_conf(self): def weights_basename_filepath(self): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Returns ------- weights_basename: str - The basename of the psf model weights to be loaded. + The basename of the PSF model weights to be loaded. """ return os.path.join( @@ -437,12 +437,12 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. 
""" logger.info( - "Running metrics evaluation on psf model: {}".format(self.weights_path) + "Running metrics evaluation on PSF model: {}".format(self.weights_path) ) model_metrics = evaluate_model( From 0959398d256479467426143480748a5bc429add8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 17:22:44 +0200 Subject: [PATCH 006/146] Refactor file structure; update import statements in tests; remove unit test due to refactoring --- src/wf_psf/{utils => data}/centroids.py | 2 +- .../data_preprocessing.py} | 2 +- src/wf_psf/data/training_preprocessing.py | 4 +-- src/wf_psf/instrument/__init__.py | 0 .../ccd_misalignments.py | 0 .../centroids_test.py | 10 +++---- .../tests/test_utils/configs_handler_test.py | 29 ++++--------------- 7 files changed, 15 insertions(+), 32 deletions(-) rename src/wf_psf/{utils => data}/centroids.py (99%) rename src/wf_psf/{utils/preprocessing.py => data/data_preprocessing.py} (99%) create mode 100644 src/wf_psf/instrument/__init__.py rename src/wf_psf/{utils => instrument}/ccd_misalignments.py (100%) rename src/wf_psf/tests/{test_utils => test_data}/centroids_test.py (97%) diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 99% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 75d3c7e9..3e1a8c71 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,7 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff from typing import Optional diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/data/data_preprocessing.py similarity index 99% rename from src/wf_psf/utils/preprocessing.py rename to src/wf_psf/data/data_preprocessing.py index 210c03e5..44e18436 100644 --- a/src/wf_psf/utils/preprocessing.py +++ b/src/wf_psf/data/data_preprocessing.py @@ -1,4 +1,4 @@ -"""Preprocessing. +"""Data Preprocessing. A module with utils to preprocess data. 
diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py index 95afebb1..fe34a3bd 100644 --- a/src/wf_psf/data/training_preprocessing.py +++ b/src/wf_psf/data/training_preprocessing.py @@ -10,8 +10,8 @@ import numpy as np import wf_psf.utils.utils as utils import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt from fractions import Fraction import logging diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 100% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 97% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 02e479e9..5ef5f86f 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,7 +9,7 @@ import numpy as np import pytest from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import ( +from wf_psf.data.centroids import ( compute_zernike_tip_tilt, CentroidEstimator ) @@ -124,7 +124,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt with single batch input and mocks.""" # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.utils.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value @@ -132,7 +132,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5 # Mocked conversion for test ) @@ -166,7 +166,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt with batch input and mocks.""" # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.utils.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value @@ -174,7 +174,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5 # Mocked conversion for test ) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 5fb7d7cd..18fdc86f 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ 
b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,18 @@ :Author: Jennifer Pollack - """ import pytest +from wf_psf.data.training_preprocessing import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + MetricsConfigHandler, + DataConfigHandler +) import os @@ -234,23 +237,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_repo_dir, path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler( - path_to_repo_dir, path_to_tmp_output_dir, path_to_config_dir - ) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) From 16d4f9479b8774af25cf131d79ffc3ee8e7a6350 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 20:35:12 +0200 Subject: [PATCH 007/146] Update package name in import statement --- src/wf_psf/instrument/ccd_misalignments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 0f51c32f..5da55ee6 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff class CCDMisalignmentCalculator: From a0fe57bda31c5b3ddcf93f7a3eb0873aef755c62 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:54:33 +0200 Subject: [PATCH 008/146] Reorder imports; Refactor MetricsConfigHandler class attributes, methods and variable names| --- src/wf_psf/utils/configs_handler.py | 191 +++++++++++++++------------- 1 file changed, 102 insertions(+), 89 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 4f78990e..bbd33dd7 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -254,103 +256,135 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) - self._file_handler = file_handler - 
self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.data_conf = self._load_data_conf() + self._file_handler = file_handler + self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self.training_conf = training_conf + self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) @property def metrics_conf(self): return self._metrics_conf - @property - def metrics_dir(self): - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): + """Returns the loaded training configuration.""" return self._training_conf + @training_conf.setter + def training_conf(self, training_conf): + """ + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info(f"Loading training config from inferred path: {training_conf_path}") + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf + @property def plotting_conf(self): return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - return self._load_data_conf() - - @property - def psf_model(self): - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + self.training_conf, self.data_conf, + weights_path_pattern, ) - @property - def weights_path(self): - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. - Helper method to get the trained model path. - - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace + def _get_training_conf_path_from_metrics(self): + """ + Retrieves the full path to the training config based on the metrics configuration. Returns ------- str - A string representing the path to the trained model output run directory. - + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. + FileNotFoundError + If the file does not exist at the constructed path. """ - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + trained_model_path = self._get_trained_model_path() - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." 
- ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, - ) + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e - def _load_training_conf(self, training_conf): - """Load Training Conf. - Load the training configuration if training_conf is not provided. + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), training_conf_filename) - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. + if not os.path.exists(training_conf_path): + raise FileNotFoundError(f"Training config file not found: {training_conf_path}") + + return training_conf_path + + + def _get_trained_model_path(self): + """ + Determine the trained model path from either: + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- - RecursiveNamespace storing the training configuration parameter settings. + str + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) + trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return training_conf + logger.info(f"Using trained model path from metrics config: {trained_model_path}") + return trained_model_path + + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + return fallback_path def _load_data_conf(self): """Load Data Conf. @@ -374,26 +408,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified PSF model weights path. - - Returns - ------- - weights_basename: str - The basename of the PSF model weights to be loaded. - - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. 
@@ -450,7 +464,6 @@ def run(self): self.training_conf.training, self.data_conf, self.psf_model, - self.weights_path, self.metrics_dir, ) From 7e34077ec2cd6e70235c47e618131dbd7094ad64 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:55:47 +0200 Subject: [PATCH 009/146] Move psf_model weights loader to psf_model_loader.py module --- src/wf_psf/metrics/metrics_interface.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 027b0213..c7233c52 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -366,14 +366,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info("Loading PSF model weights from {}".format(weights_path)) - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} From d6c8dfdc0dd6009a687ae0c5d1e8fd854f6f2c02 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:56:13 +0200 Subject: [PATCH 010/146] Add psf_model_loader module --- src/wf_psf/psf_models/psf_model_loader.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/wf_psf/psf_models/psf_model_loader.py diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..26056e57 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,52 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" +from wf_psf.psf_models.psf_models import ( + get_psf_model, + get_psf_model_weights_filepath +) + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace + Configuration object containing data-related parameters. + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. + + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. 
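[Note on PATCH 008] For concreteness, the glob assembled by the new
_load_trained_psf_model can be previewed outside the class. A minimal sketch,
reusing the config values from the unit test removed in PATCH 006
(model_save_path "checkpoint", model_name "poly", id_name "_sample_w_bis1_2k",
saved_training_cycle 2); the run directory below is a placeholder:

    import os
    import glob

    trained_model_path = "src/wf_psf/tests/data/validation/main_random_seed"  # placeholder
    weights_path_pattern = os.path.join(
        trained_model_path,
        "checkpoint",
        "checkpoint*_poly*_sample_w_bis1_2k_cycle2*",
    )
    # get_psf_model_weights_filepath() resolves the pattern with glob
    # and strips the file extension:
    weights_path = glob.glob(weights_path_pattern)[0].split(".")[0]

This reproduces the pattern previously asserted by the removed
test_MetricsConfigHandler_weights_basename_filepath.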
+ """ + model = get_psf_model(training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + model.load_weights(weights_path) + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model + From fa4f890b390c888d3bb9d84af008a990a501fb1f Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:13:38 +0200 Subject: [PATCH 011/146] Remove weights_path arg from evaluate_model method; Update logger.info statement and comment --- src/wf_psf/metrics/metrics_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index c7233c52..5680cecc 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -325,7 +325,6 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. @@ -356,8 +355,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) From 0eb4800111260df0b4e1334191f30355ae99273d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:14:22 +0200 Subject: [PATCH 012/146] Update variable name and logger statement --- src/wf_psf/utils/configs_handler.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index bbd33dd7..7b391f27 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -256,12 +256,12 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) + self._file_handler = file_handler + self.training_conf = training_conf self.data_conf = self._load_data_conf() - self._file_handler = file_handler self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - self.training_conf = training_conf - self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) - + self.trained_psf_model = self._load_trained_psf_model() + @property def metrics_conf(self): return self._metrics_conf @@ -455,15 +455,13 @@ def run(self): input configuration. 
""" - logger.info( - "Running metrics evaluation on PSF model: {}".format(self.weights_path) - ) + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, + self.trained_psf_model, self.metrics_dir, ) From 079b990fe733ec5b4b79ffdb7622e98f53ea1e6d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:15:53 +0200 Subject: [PATCH 013/146] Add import logging and create logger object --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 26056e57..1d2e267f 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,11 +7,15 @@ Author: Jennifer Pollack """ +import logging from wf_psf.psf_models.psf_models import ( get_psf_model, get_psf_model_weights_filepath ) + +logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): """ Loads a trained PSF model and applies saved weights. From 2aa8db7a8633e838ccb534148b8b40587b163ec0 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:18:43 +0200 Subject: [PATCH 014/146] Create psf_inference.py module --- src/wf_psf/inference/psf_inference.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e69de29b..4f5b39e5 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,21 @@ +"""Inference. + +A module which provides a set of functions to perform inference +on PSF models. It includes functions to load a trained model, +perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. + +:Authors: Jennifer Pollack + +""" + +import os +import glob +import logging +import numpy as np +from wf_psf.psf_models import psf_models, psf_model_loader +import tensorflow as tf + + +#def prepare_inputs(...): ... +#def generate_psfs(...): ... +#def run_pipeline(...): ... 
From b33f171fd774a67610edb2944db0609ff996f47e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:21:11 +0200 Subject: [PATCH 015/146] Remove arg from evaluate_model unit test --- src/wf_psf/tests/test_metrics/metrics_interface_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 8b498295..f3e19441 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -76,7 +76,6 @@ def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_dat trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output" ) From cc575115ac799ac22a37458f73b2e2b4f37a1564 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 016/146] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/centroids.py | 71 ++++- ...ining_preprocessing.py => data_handler.py} | 230 +------------- src/wf_psf/data/data_preprocessing.py | 102 ------ src/wf_psf/data/data_zernike_utils.py | 295 ++++++++++++++++++ src/wf_psf/instrument/ccd_misalignments.py | 40 +++ .../psf_models/tf_modules/tf_psf_field.py | 2 +- src/wf_psf/tests/__init__.py | 1 + src/wf_psf/tests/conftest.py | 2 +- src/wf_psf/tests/test_data/__init__.py | 0 src/wf_psf/tests/test_data/centroids_test.py | 183 ++++------- src/wf_psf/tests/test_data/conftest.py | 61 ++++ ...rocessing_test.py => data_handler_test.py} | 179 ++--------- .../test_data/data_zernike_utils_test.py | 133 ++++++++ src/wf_psf/tests/test_data/test_data_utils.py | 30 ++ .../test_metrics/metrics_interface_test.py | 2 +- src/wf_psf/tests/test_psf_models/conftest.py | 2 +- .../psf_model_physical_polychromatic_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 2 +- src/wf_psf/utils/configs_handler.py | 2 +- 19 files changed, 744 insertions(+), 595 deletions(-) rename src/wf_psf/data/{training_preprocessing.py => data_handler.py} (50%) delete mode 100644 src/wf_psf/data/data_preprocessing.py create mode 100644 src/wf_psf/data/data_zernike_utils.py create mode 100644 src/wf_psf/tests/__init__.py create mode 100644 src/wf_psf/tests/test_data/__init__.py rename src/wf_psf/tests/test_data/{training_preprocessing_test.py => data_handler_test.py} (55%) create mode 100644 src/wf_psf/tests/test_data/data_zernike_utils_test.py create mode 100644 src/wf_psf/tests/test_data/test_data_utils.py diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 3e1a8c71..01ecea4e 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,10 +8,79 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_handler import extract_star_data +from fractions import Fraction +from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff +import tensorflow as tf from typing import Optional +def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. 
+ + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + data : DataConfigHandler + An object containing training and test datasets, including observed PSFs + and optional star masks. + + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") + + # Get star mask catalogue only if "masks" exist in both training and test datasets + star_masks = ( + extract_star_data(data=data, train_key="masks", test_key="masks") + if ( + data.training_data.dataset.get("masks") is not None + and data.test_data.dataset.get("masks") is not None + and tf.size(data.training_data.dataset["masks"]) > 0 + and tf.size(data.test_data.dataset["masks"]) > 0 + ) + else None + ) + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + # Ensure star_masks is properly handled + star_masks = star_masks if star_masks is not None else None + + reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i:i + batch_size] + batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/data_handler.py similarity index 50% rename from src/wf_psf/data/training_preprocessing.py rename to src/wf_psf/data/data_handler.py index fe34a3bd..fd0e3474 100644 --- a/src/wf_psf/data/training_preprocessing.py +++ b/src/wf_psf/data/data_handler.py @@ -1,17 +1,21 @@ -"""Training Data Processing. +"""Data Handler Module. -A module to load and preprocess training and validation test data. +Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. -:Authors: Jennifer Pollack and Tobias Liaudat +Includes: +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations +This module serves as a central interface between raw data and modeling components. 
+ +Authors: Jennifer Pollack , Tobias Liaudat """ import os import numpy as np import wf_psf.utils.utils as utils import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt from fractions import Fraction import logging @@ -80,12 +84,13 @@ def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: self.data_params = data_params.__dict__[dataset_type] self.simPSF = simPSF self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None self.load_data_on_init = load_data if self.load_data_on_init: self.load_dataset() self.process_sed_data() + else: + self.dataset = None + self.sed_data = None def load_dataset(self): @@ -115,7 +120,8 @@ def load_dataset(self): ) else: logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - + elif "inference" == self.dataset_type: + pass def process_sed_data(self): """Process SED Data. @@ -235,211 +241,3 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. 
- """ - star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i:i + batch_size] - batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. - """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int=16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. 
- - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/data/data_preprocessing.py b/src/wf_psf/data/data_preprocessing.py deleted file mode 100644 index 44e18436..00000000 --- a/src/wf_psf/data/data_preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Data Preprocessing. - -A module with utils to preprocess data. - -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. 
- """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..760adb11 --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,295 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. + +:Author: Tobias Liaudat + +""" + +import numpy as np +import tensorflow as tf +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def get_zernike_prior(model_params, data, batch_size: int=16): + """Get Zernike priors from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + + Notes + ----- + The Zernike prior are obtained by concatenating the Zernike priors + from both the training and test datasets along the 0th axis. 
+
+    """
+    # List of Zernike contributions
+    zernike_contribution_list = []
+
+    if model_params.use_prior:
+        logger.info("Reading in Zernike prior into Zernike contribution list...")
+        zernike_contribution_list.append(get_np_zernike_prior(data))
+
+    if model_params.correct_centroids:
+        logger.info("Adding centroid correction to Zernike contribution list...")
+        zernike_contribution_list.append(
+            compute_centroid_correction(model_params, data, batch_size)
+        )
+
+    if model_params.add_ccd_misalignments:
+        logger.info("Adding CCD mis-alignments to Zernike contribution list...")
+        zernike_contribution_list.append(compute_ccd_misalignment(model_params, data))
+
+    if len(zernike_contribution_list) == 1:
+        zernike_contribution = zernike_contribution_list[0]
+    else:
+        # Get max zk order
+        max_zk_order = np.max(
+            np.array(
+                [
+                    zk_contribution.shape[1]
+                    for zk_contribution in zernike_contribution_list
+                ]
+            )
+        )
+
+        zernike_contribution = np.zeros(
+            (zernike_contribution_list[0].shape[0], max_zk_order)
+        )
+
+        # Pad arrays to get the same length and add the final contribution
+        for it in range(len(zernike_contribution_list)):
+            current_zk_order = zernike_contribution_list[it].shape[1]
+            current_zernike_contribution = np.pad(
+                zernike_contribution_list[it],
+                pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))],
+                mode="constant",
+                constant_values=0,
+            )
+
+            zernike_contribution += current_zernike_contribution
+
+    return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32)
+
+
+def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2):
+    """Compute Zernike 1 (2) for a given shift in x (y) in WaveDiff conventions.
+
+    All inputs should be in [m].
+    A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale,
+    e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`.
+
+    The output Zernike coefficient is in [um] units, as expected by WaveDiff.
+
+    To match the centroid with a `dx` that has a corresponding `zk1`,
+    the new PSF should be generated with `-zk1`.
+
+    The same applies to `dy` and `zk2`.
+
+    Parameters
+    ----------
+    dxy : float
+        Centroid shift in [m]. It can be on the x-axis or the y-axis.
+    tel_focal_length : float
+        Telescope focal length in [m].
+    tel_diameter : float
+        Telescope aperture diameter in [m].
+    """
+    reference_pix_sampling = 12e-6
+    zernike_norm_factor = 2.0
+
+    # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2)
+    return (
+        zernike_norm_factor
+        * (tel_diameter / 2)
+        * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length))
+        * 3.0
+    )
+
+def compute_zernike_tip_tilt(
+    star_images: np.ndarray,
+    star_masks: Optional[np.ndarray] = None,
+    pixel_sampling: float = 12e-6,
+    reference_shifts: list[float] = [-1/3, -1/3],
+    sigma_init: float = 2.5,
+    n_iter: int = 20,
+) -> np.ndarray:
+    """
+    Compute Zernike tip-tilt corrections for a batch of PSF images.
+
+    This function estimates the centroid shifts of multiple PSFs and computes
+    the corresponding Zernike tip-tilt corrections to align them with a reference.
+
+    Parameters
+    ----------
+    star_images : np.ndarray
+        A batch of PSF images (3D array of shape `(num_images, height, width)`).
+    star_masks : np.ndarray, optional
+        A batch of masks (same shape as `star_images`). Each mask can have:
+        - `0` to ignore the pixel.
+        - `1` to fully consider the pixel.
+        - Values in `(0,1]` as weights for partial consideration.
+        Defaults to None.
+    pixel_sampling : float, optional
+        The pixel size in meters. Defaults to `12e-6 m` (12 microns).
+    reference_shifts : list[float], optional
+        The target centroid shifts in pixels, specified as `[dy, dx]`.
+        Defaults to `[-1/3, -1/3]` (nominal Euclid conditions).
+    sigma_init : float, optional
+        Initial standard deviation for centroid estimation. Default is `2.5`.
+    n_iter : int, optional
+        Number of iterations for centroid refinement. Default is `20`.
+
+    Returns
+    -------
+    np.ndarray
+        An array of shape `(num_images, 2)`, where:
+        - Column 0 contains `Zk1` (tip) values.
+        - Column 1 contains `Zk2` (tilt) values.
+
+    Notes
+    -----
+    - This function processes all images at once using vectorized operations.
+    - The Zernike coefficients are computed in the WaveDiff convention.
+    """
+    from wf_psf.data.centroids import CentroidEstimator
+
+    # Vectorize the centroid computation
+    centroid_estimator = CentroidEstimator(
+        im=star_images,
+        mask=star_masks,
+        sigma_init=sigma_init,
+        n_iter=n_iter
+    )
+
+    shifts = centroid_estimator.get_intra_pixel_shifts()
+
+    # Ensure reference_shifts is a NumPy array (if it's not already)
+    reference_shifts = np.array(reference_shifts)
+
+    # Reshape to a row vector of shape (1, 2)
+    reference_shifts = reference_shifts[None, :]
+
+    # Broadcast reference_shifts to match the shape of shifts
+    reference_shifts = np.broadcast_to(reference_shifts, shifts.shape)
+
+    # Compute displacements
+    displacements = reference_shifts - shifts
+
+    # Ensure the correct axis order for displacements (x-axis, then y-axis)
+    displacements_swapped = displacements[:, [1, 0]]  # Adjust axis order if necessary
+
+    # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements
+    zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling)  # vectorized call
+
+    # Reshape the result back to the original shape of displacements
+    zk1_2_array = zk1_2_array.reshape(displacements.shape)
+
+    return zk1_2_array
+
+
+def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2):
+    """Compute Zernike 4 value for a given defocus in Zemax conventions.
+
+    All inputs should be in [m].
+
+    Parameters
+    ----------
+    dz : float
+        Shift in the z-axis, perpendicular to the focal plane. Units in [m].
+    tel_focal_length : float
+        Telescope focal length in [m].
+    tel_diameter : float
+        Telescope aperture diameter in [m].
+    """
+    # Base calculation
+    zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2)
+    # Apply Z4 normalisation
+    # This step depends on the normalisation of the Zernike basis used
+    zk4 /= np.sqrt(3)
+    # Convert to waves with a reference of 800nm
+    zk4 /= 800e-9
+    # Remove the peak to valley value
+    zk4 /= 2.0
+
+    return zk4
+
+
+def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2):
+    """Compute Zernike 4 value for a given defocus in WaveDiff conventions.
+
+    All inputs should be in [m].
+
+    The output Zernike coefficient is in [um] units, as expected by WaveDiff.
+
+    Parameters
+    ----------
+    dz : float
+        Shift in the z-axis, perpendicular to the focal plane. Units in [m].
+    tel_focal_length : float
+        Telescope focal length in [m].
+    tel_diameter : float
+        Telescope aperture diameter in [m].
+    """
+    # Base calculation
+    zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2)
+    # Apply Z4 normalisation
+    # This step depends on the normalisation of the Zernike basis used
+    zk4 /= np.sqrt(3)
+
+    # Remove the peak to valley value
+    zk4 /= 2.0
+
+    # Change units to [um] as WaveDiff uses
+    zk4 *= 1e6
+
+    return zk4
diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py
index 5da55ee6..1b153cb3 100644
--- a/src/wf_psf/instrument/ccd_misalignments.py
+++ b/src/wf_psf/instrument/ccd_misalignments.py
@@ -13,6 +13,46 @@ from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff
 
 
+def compute_ccd_misalignment(model_params, data):
+    """Compute CCD misalignment.
+
+    Parameters
+    ----------
+    model_params : RecursiveNamespace
+        Object containing parameters for this PSF model class.
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    zernike_ccd_misalignment_array : np.ndarray
+        Numpy array containing the Zernike contributions to model the CCD chip misalignments.
+    """
+    obs_positions = get_np_obs_positions(data)
+
+    ccd_misalignment_calculator = CCDMisalignmentCalculator(
+        tiles_path=model_params.ccd_misalignments_input_path,
+        x_lims=model_params.x_lims,
+        y_lims=model_params.y_lims,
+        tel_focal_length=model_params.tel_focal_length,
+        tel_diameter=model_params.tel_diameter,
+    )
+    # Compute the required Zernike 4 for each position
+    zk4_values = np.array(
+        [
+            ccd_misalignment_calculator.get_zk4_from_position(single_pos)
+            for single_pos in obs_positions
+        ]
+    ).reshape(-1, 1)
+
+    # Zero pad array to get shape (n_stars, n_zernike=4)
+    zernike_ccd_misalignment_array = np.pad(
+        zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0
+    )
+
+    return zernike_ccd_misalignment_array
+
+
 class CCDMisalignmentCalculator:
     """CCD Misalignment Calculator.
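To make the defocus conversion and padding scheme above concrete, here is a small worked sketch. It simply restates `defocus_to_zk4_wavediff` with the default telescope values used throughout these patches and mirrors the zero-padding applied by `compute_ccd_misalignment`; the three star positions are hypothetical, for illustration only.

    import numpy as np

    def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2):
        # Defocus [m] -> Z4 in [um]: geometric term, sqrt(3) Zernike
        # normalisation, peak-to-valley factor of 2, then metres -> microns.
        zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2)
        zk4 /= np.sqrt(3)
        zk4 /= 2.0
        return zk4 * 1e6

    # A 10 um displacement along the optical axis gives a Z4 of ~8.7e-4 [um].
    zk4 = defocus_to_zk4_wavediff(dz=10e-6)

    # Pad the per-star Z4 column into shape (n_stars, 4) so the defocus term
    # occupies the fourth Zernike slot, exactly as compute_ccd_misalignment does.
    zk4_values = np.full((3, 1), zk4)  # three hypothetical stars
    zernike_array = np.pad(
        zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0
    )
    assert zernike_array.shape == (3, 4)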
diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 45ea91dd..f39f8bdd 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.data.data_handler import get_obs_positions from wf_psf.psf_models import psf_models as psfm import logging diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 5ef5f86f..08c05382 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,11 +8,14 @@ import numpy as np import pytest -from unittest.mock import MagicMock, patch from wf_psf.data.centroids import ( - compute_zernike_tip_tilt, + compute_centroid_correction, CentroidEstimator ) +from wf_psf.data.data_handler import extract_star_data +from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt +from wf_psf.utils.read_config import RecursiveNamespace +from unittest.mock import MagicMock, patch # Function to compute centroid based on first-order moments def calculate_centroid(image, mask=None): @@ -29,22 +32,6 @@ def calculate_centroid(image, mask=None): yc = M01 / M00 return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - return image - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images @pytest.fixture def simple_star_and_mask(): @@ -64,13 +51,6 @@ def simple_star_and_mask(): return image, mask - -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" @@ -120,101 +100,74 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - - # Mock the CentroidEstimator class - mock_centroid_calc = 
mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"] ) - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - - # Expected shifts based on centroid calculation - expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters - expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - - # Check dy values similarly - np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 - np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" + # Mock the internal function calls: + with patch('wf_psf.data.centroids.extract_star_data') as mock_extract_star_data, \ + patch('wf_psf.data.centroids.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: + + # Mock the return values of extract_star_data and compute_zernike_tip_tilt + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) + ) + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, mock_data) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose(result[0, :], np.array([0, -0.1, -0.2])) # First star Zernike coefficients + assert np.allclose(result[1, :], np.array([0, -0.3, -0.4])) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Remove masks from mock_data + mock_data.test_data.dataset["masks"] = None + mock_data.training_data.dataset["masks"] 
= None - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"] ) - - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts + + # Mock internal function calls + with patch('wf_psf.data.centroids.extract_star_data') as mock_extract_star_data, \ + patch('wf_psf.data.centroids.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: + + # Mock extract_star_data to return synthetic star postage stamps + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) ) - - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" - - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] - - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) - - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) - - # Process the displacements and expected values for each image in the batch - expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image - np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image - + + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call function under test + result = compute_centroid_correction(model_params, mock_data) + + # Validate result shape + assert result.shape == (4, 3) # (n_stars, 3 Zernike components) + + # Validate expected values (adjust based on behavior) + expected_result = np.array([ + [0, -0.1, -0.2], # From training data + [0, -0.3, -0.4], + [0, -0.1, -0.2], # From test data (reused mocked return) + [0, 
-0.3, -0.4]
+    ])
+    assert np.allclose(result, expected_result)
 
 
 # Test for centroid calculation without mask
diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py
index 6159d53a..04a56893 100644
--- a/src/wf_psf/tests/test_data/conftest.py
+++ b/src/wf_psf/tests/test_data/conftest.py
@@ -9,8 +9,11 @@
 """
 
 import pytest
+import numpy as np
+import tensorflow as tf
 from wf_psf.utils.read_config import RecursiveNamespace
 from wf_psf.psf_models import psf_models
+from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset
 
 training_config = RecursiveNamespace(
     id_name="-coherent_euclid_200stars",
@@ -93,6 +96,64 @@
 )
 
 
+@pytest.fixture(scope="module")
+def mock_data():
+    """Fixture to provide mock data for testing."""
+    # Mock positions and Zernike priors
+    training_positions = np.array([[1, 2], [3, 4]])
+    test_positions = np.array([[5, 6], [7, 8]])
+    training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]])
+    test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]])
+
+    # Define dummy 5x5 image patches for stars (mock star images)
+    # Define varied values for 5x5 star images
+    noisy_stars = tf.constant([
+        np.arange(25).reshape(5, 5),
+        np.arange(25, 50).reshape(5, 5)
+    ], dtype=tf.float32)
+
+    noisy_masks = tf.constant([
+        np.eye(5),
+        np.ones((5, 5))
+    ], dtype=tf.float32)
+
+    stars = tf.constant([
+        np.full((5, 5), 100),
+        np.full((5, 5), 200)
+    ], dtype=tf.float32)
+
+    masks = tf.constant([
+        np.zeros((5, 5)),
+        np.tri(5)
+    ], dtype=tf.float32)
+
+    return MockData(
+        training_positions, test_positions, training_zernike_priors,
+        test_zernike_priors, noisy_stars, noisy_masks, stars, masks
+    )
+
+@pytest.fixture(scope="module")
+def simple_image():
+    """Fixture for a simple star image."""
+    num_images = 1  # Change this to test with multiple images
+    image = np.zeros((num_images, 5, 5))  # Create a 3D array
+    image[:, 2, 2] = 1  # Place the star at the center for each image
+    return image
+
+@pytest.fixture(scope="module")
+def identity_mask():
+    """Creates a mask where all pixels are fully considered."""
+    return np.ones((5, 5))
+
+@pytest.fixture(scope="module")
+def multiple_images():
+    """Fixture for a batch of images with stars at different positions."""
+    images = np.zeros((3, 5, 5))  # 3 images, each of size 5x5
+    images[0, 2, 2] = 1  # Star at center of image 0
+    images[1, 1, 3] = 1  # Star at (1, 3) in image 1
+    images[2, 3, 1] = 1  # Star at (3, 1) in image 2
+    return images
+
 @pytest.fixture(scope="module", params=[data])
 def data_params():
     return data
diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/data_handler_test.py
similarity index 55%
rename from src/wf_psf/tests/test_data/training_preprocessing_test.py
rename to src/wf_psf/tests/test_data/data_handler_test.py
index 6769d0e5..3a908274 100644
--- a/src/wf_psf/tests/test_data/training_preprocessing_test.py
+++ b/src/wf_psf/tests/test_data/data_handler_test.py
@@ -1,66 +1,15 @@
 import pytest
 import numpy as np
 import tensorflow as tf
-from wf_psf.utils.read_config import RecursiveNamespace
-from wf_psf.data.training_preprocessing import (
+from wf_psf.data.data_handler import (
     DataHandler,
     get_obs_positions,
-    get_zernike_prior,
     extract_star_data,
-    compute_centroid_correction,
 )
+from wf_psf.utils.read_config import RecursiveNamespace
 import logging
 from unittest.mock import patch
 
-class MockData:
-    def __init__(
-        self,
-        training_positions,
-        test_positions,
-        training_zernike_priors,
-        test_zernike_priors,
-        noisy_stars=None,
-
noisy_masks=None, - stars=None, - masks=None, - ): - self.training_data = MockDataset( - positions=training_positions, - zernike_priors=training_zernike_priors, - star_type="noisy_stars", - stars=noisy_stars, - masks=noisy_masks) - self.test_data = MockDataset( - positions=test_positions, - zernike_priors=test_zernike_priors, - star_type="stars", - stars=stars, - masks=masks) - - -class MockDataset: - def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} - - -@pytest.fixture -def mock_data(): - # Mock data for testing - # Mock training and test positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) - # Mock noisy stars, stars and masks - noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) - noisy_masks = tf.constant([[1], [0]], dtype=tf.float32) - stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32) - masks = tf.constant([[0], [1]], dtype=tf.float32) - - return MockData( - training_positions, test_positions, training_zernike_priors, test_zernike_priors, noisy_stars, noisy_masks, stars, masks - ) - def test_process_sed_data(data_params, simPSF): # Test processing SED data without initialization @@ -164,7 +113,7 @@ def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): n_bins_lambda = 10 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") @@ -188,7 +137,7 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): n_bins_lambda = 10 data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'stars' in test dataset.") @@ -221,48 +170,38 @@ def test_get_obs_positions(mock_data): assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - 
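# The extract_star_data tests below pin down the expected concatenation
# semantics. The implementation itself lives in wf_psf.data.data_handler and
# is not shown in this patch series; a minimal sketch consistent with these
# tests (names and the error message are taken from the assertions) might be:

import numpy as np

def extract_star_data(data, train_key, test_key):
    # Concatenate the requested training and test entries along axis 0
    # and return a float32 NumPy array, raising on absent keys.
    pairs = [(train_key, data.training_data.dataset), (test_key, data.test_data.dataset)]
    missing = [key for key, ds in pairs if key not in ds or ds[key] is None]
    if missing:
        raise KeyError(f"Missing keys in dataset: {missing}")
    return np.concatenate(
        [np.asarray(ds[key], dtype=np.float32) for key, ds in pairs],
        axis=0,
    )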
def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + expected = tf.concat([ + tf.constant([ + np.arange(25).reshape(5, 5), + np.arange(25, 50).reshape(5, 5) + ], dtype=tf.float32), + tf.constant([ + np.full((5, 5), 100), + np.full((5, 5), 200) + ], dtype=tf.float32), + ], axis=0) + np.testing.assert_array_equal(result, expected) + def test_extract_star_data_masks(mock_data): """Test extracting star masks from the dataset.""" result = extract_star_data(mock_data, train_key="masks", test_key="masks") - - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) + + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + def test_extract_star_data_missing_key(mock_data): """Test that the function raises a KeyError when a key is missing.""" with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): @@ -281,74 +220,6 @@ def test_extract_star_data_tensor_conversion(mock_data): assert isinstance(result, np.ndarray), "The result should be a NumPy array" assert result.dtype == np.float32, "The NumPy array should have dtype float32" -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] - ) - - # Mock the internal function calls: - with patch('wf_psf.data.training_preprocessing.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.training_preprocessing.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: - - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose(result[0, :], np.array([0, -0.1, -0.2])) # First star Zernike coefficients - assert np.allclose(result[1, :], np.array([0, -0.3, -0.4])) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] - ) - - # Mock internal function calls - with patch('wf_psf.data.training_preprocessing.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.training_preprocessing.compute_zernike_tip_tilt') as 
mock_compute_zernike_tip_tilt: - - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array([ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4] - ]) - assert np.allclose(result, expected_result) def test_reference_shifts_broadcasting(): reference_shifts = [-1/3, -1/3] # Example reference_shifts diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..692624be --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,133 @@ + +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_zernike_utils import ( + get_zernike_prior, + compute_zernike_tip_tilt, +) +from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset + +def test_get_zernike_prior(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_shape = ( + 4, + 2, + ) # Assuming 2 Zernike priors for each dataset (training and test) + assert zernike_priors.shape == expected_shape + + +def test_get_zernike_prior_dtype(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + assert zernike_priors.dtype == np.float32 + + +def test_get_zernike_prior_concatenation(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_zernike_priors = tf.convert_to_tensor( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + ) + + assert np.array_equal(zernike_priors, expected_zernike_priors) + + +def test_get_zernike_prior_empty_data(model_params): + empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) + zernike_priors = get_zernike_prior(model_params, empty_data) + assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + + +def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): + """Test compute_zernike_tip_tilt with single batch input and mocks.""" + + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + ) + + # Define test inputs (batch of 1 image) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + 
assert zernike_corrections.shape == (1, 2)  # one image: (Zk1, Zk2)
+
+    # Expected shifts based on centroid calculation
+    expected_dx = (reference_shifts[1] - (-0.02))  # Expected x-axis shift in pixels
+    expected_dy = (reference_shifts[0] - 0.05)  # Expected y-axis shift in pixels
+
+    # Expected calls to the mocked function
+    # Extract the arguments passed to mock_shift_fn
+    args, _ = mock_shift_fn.call_args_list[0]  # Get the first call args
+
+    # Compare expected values with the actual arguments passed to the mock function
+    np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0)
+
+    # Check dy values similarly
+    np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0)
+
+    # Expected values based on mock side_effect (0.5 * shift)
+    np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5)  # Zk1
+    np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5)  # Zk2
+
+def test_compute_zernike_tip_tilt_batch(mocker, multiple_images):
+    """Test compute_zernike_tip_tilt with batch input and mocks."""
+
+    # Mock the CentroidEstimator class
+    mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True)
+
+    # Create a mock instance and configure get_intra_pixel_shifts()
+    mock_instance = mock_centroid_calc.return_value
+    mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]])  # Shape (3, 2)
+
+    # Mock shift_x_y_to_zk1_2_wavediff to return predictable values
+    mock_shift_fn = mocker.patch(
+        "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff",
+        side_effect=lambda shift: shift * 0.5  # Mocked conversion for test
+    )
+
+    # Define test inputs (batch of 3 images)
+    pixel_sampling = 12e-6
+    reference_shifts = [-1 / 3, -1 / 3]  # Default Euclid conditions
+
+    # Run the function
+    zernike_corrections = compute_zernike_tip_tilt(
+        star_images=multiple_images,
+        pixel_sampling=pixel_sampling,
+        reference_shifts=reference_shifts
+    )
+
+    # Check if the mock function was called once with the full batch
+    assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}"
+
+    # Get the arguments passed to the mock function for the batch of images
+    args, _ = mock_shift_fn.call_args_list[0]
+
+    print("Shape of args[0]:", args[0].shape)
+    print("Contents of args[0]:", args[0])
+    print("Mock function call args list:", mock_shift_fn.call_args_list)
+
+    # Reshape args[0] to (N, 2) for batch processing
+    args_array = np.array(args[0]).reshape(-1, 2)
+
+    # Process the displacements and expected values for each image in the batch
+    expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1]  # Expected x-axis shift in pixels
+
+    expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0]  # Expected y-axis shift in pixels
+
+    # Compare expected values with the actual arguments passed to the mock function
+    np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0)
+    np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0)
+
+    # Expected values based on mock side_effect (0.5 * shift)
+    np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5)  # Zk1 for each image
+    np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5)  # Zk2 for
each image \ No newline at end of file diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..a5ead298 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,30 @@ + +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks) + self.test_data = MockDataset( + positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks) + diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index f3e19441..225ba894 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -2,7 +2,7 @@ from unittest.mock import patch, MagicMock import pytest from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler @pytest.fixture def mock_metrics_params(): diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index e25d608d..7e967465 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -59,7 +59,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): ) mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True + "wf_psf.data.data_handler.get_obs_positions", return_value=True ) # Create TFPhysicalPolychromaticField instance diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 18fdc86f..ebddac66 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -7,7 +7,7 @@ """ import pytest -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 7b391f27..2260abe3 100644 --- 
a/src/wf_psf/utils/configs_handler.py
+++ b/src/wf_psf/utils/configs_handler.py
@@ -12,7 +12,7 @@
 import os
 import re
 import glob
-from wf_psf.data.training_preprocessing import DataHandler
+from wf_psf.data.data_handler import DataHandler
 from wf_psf.metrics.metrics_interface import evaluate_model
 from wf_psf.plotting.plots_interface import plot_metrics
 from wf_psf.psf_models import psf_models

From 478dfb27bb124fd439b97300fac4ed27236bfae6 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 16 May 2025 14:44:07 +0200
Subject: [PATCH 017/146] Update import statements to new module names

---
 .../psf_models/models/psf_model_physical_polychromatic.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index f9ed8765..baec8923 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -10,7 +10,8 @@
 from typing import Optional
 import tensorflow as tf
 from tensorflow.python.keras.engine import data_adapter
-from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior
+from wf_psf.data.data_handler import get_obs_positions
+from wf_psf.data.data_zernike_utils import get_zernike_prior
 from wf_psf.psf_models import psf_models as psfm
 from wf_psf.psf_models.tf_modules.tf_layers import (
     TFPolynomialZernikeField,

From d79e89dca9d524e6fbb52dd8f04783256d937ecb Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 16 May 2025 14:57:57 +0200
Subject: [PATCH 018/146] Update DataHandler class docstring to include option
 for inference dataset handling

---
 src/wf_psf/data/data_handler.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index fd0e3474..692b482e 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -25,7 +25,8 @@
 class DataHandler:
     """Data Handler.
 
-    This class manages loading and processing of training and testing data for use during PSF model training and validation.
+    This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation.
+
 
     It provides methods to access and preprocess the data.
 
     Parameters
@@ -45,9 +46,9 @@
     Attributes
     ----------
     dataset_type: str
-        A string indicating the type of dataset ("train" or "test").
+        A string indicating the type of dataset ("train", "test" or "inference").
     data_params: Recursive Namespace object
-        A Recursive Namespace object containing training or test data parameters.
+        A Recursive Namespace object containing training, test or inference data parameters.
     dataset: dict
         A dictionary containing the loaded dataset, including positions and stars/noisy_stars.
     simPSF: object
@@ -68,9 +69,9 @@
         Parameters
        ----------
         dataset_type : str
-            A string indicating the type of data ("train" or "test").
+            A string indicating the type of data ("train", "test", or "inference").
         data_params : Recursive Namespace object
-            A Recursive Namespace object containing parameters for both 'train' and 'test' datasets.
+            A Recursive Namespace object containing parameters for the 'train', 'test', and 'inference' datasets.
         simPSF : PSFSimulator
             An instance of the PSFSimulator class for simulating a PSF.
n_bins_lambda : int From e9e066ed995040cd9c3c2453c8dca1e3f3f1ec68 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 019/146] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 200 ++++++++++++++++++++++++-------- 1 file changed, 149 insertions(+), 51 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 692b482e..c6d51a0d 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,78 +17,122 @@ import wf_psf.utils.utils as utils import tensorflow as tf from fractions import Fraction +from typing import Optional, Union import logging logger = logging.getLogger(__name__) class DataHandler: - """Data Handler. + """ + DataHandler for WaveDiff PSF modeling. - This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation. - - It provides methods to access and preprocess the data. + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, and inference in the WaveDiff framework. Parameters ---------- - dataset_type: str - A string indicating type of data ("train" or "test"). - data_params: Recursive Namespace object - Recursive Namespace object containing training data parameters - simPSF: PSFSimulator - An instance of the PSFSimulator class for simulating a PSF. - n_bins_lambda: int - The number of bins in wavelength. - load_data: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). + simPSF : PSFSimulator + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. + n_bins_lambda : int + Number of wavelength bins used to discretize SEDs. + load_data : bool, optional + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. Attributes ---------- - dataset_type: str - A string indicating the type of dataset ("train", "test" or "inference"). - data_params: Recursive Namespace object - A Recursive Namespace object containing training, test or inference data parameters. - dataset: dict - A dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF: object - An instance of the SimPSFToolkit class for simulating PSF. - n_bins_lambda: int - The number of bins in wavelength. - sed_data: tf.Tensor - A TensorFlow tensor containing the SED data for training/testing. - load_data_on_init: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. 
+    simPSF : PSFSimulator
+        Simulator used to transform SEDs into TensorFlow-ready tensors.
+    n_bins_lambda : int
+        Number of wavelength bins in the SED representation.
+    load_data_on_init : bool
+        Whether data was loaded automatically during initialization.
+    dataset : dict
+        Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar.
+    sed_data : tf.Tensor
+        TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features].
     """
 
-    def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool=True):
+    def __init__(
+        self,
+        dataset_type,
+        data_params,
+        simPSF,
+        n_bins_lambda,
+        load_data: bool = True,
+        dataset: Optional[Union[dict, list]] = None,
+        sed_data: Optional[Union[dict, list]] = None,
+    ):
         """
-        Initialize the dataset handler for PSF simulation.
+        Initialize the DataHandler for PSF dataset preparation.
+
+        This constructor sets up the dataset handler used for PSF simulation tasks,
+        such as training, testing, or inference. It supports three modes of use:
+
+        1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing
+           must be triggered manually via `load_dataset()` and `process_sed_data()`.
+        2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly,
+           and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`.
+        3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded
+           from disk using `data_params`, and SEDs are extracted and processed automatically.
 
         Parameters
         ----------
         dataset_type : str
-            A string indicating the type of data ("train", "test", or "inference").
-        data_params : Recursive Namespace object
-            A Recursive Namespace object containing parameters for the 'train', 'test', and 'inference' datasets.
+            One of {"train", "test", "inference"} indicating dataset usage.
+        data_params : RecursiveNamespace
+            Configuration object with paths, preprocessing options, and metadata.
         simPSF : PSFSimulator
-            An instance of the PSFSimulator class for simulating a PSF.
+            Used to convert SEDs to TensorFlow format.
         n_bins_lambda : int
-            The number of bins in wavelength.
+            Number of wavelength bins for the SEDs.
         load_data : bool, optional
-            A flag to control whether data should be loaded and processed during initialization.
-            If True, data is loaded and processed during initialization; if False, data loading
-            is deferred until explicitly called.
+            Whether to automatically load and process the dataset (default: True).
+        dataset : dict or list, optional
+            A pre-loaded dataset to use directly (overrides `load_data`).
+        sed_data : array-like, optional
+            Pre-loaded SED data to use directly. If not provided but `dataset` is,
+            SEDs are taken from `dataset["SEDs"]`.
+
+        Raises
+        ------
+        ValueError
+            If SEDs cannot be found in either `dataset` or as `sed_data`.
+
+        Notes
+        -----
+        - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor
+          `load_data=True` is used.
+        - TensorFlow conversion is performed at the end of initialization via `validate_and_process_dataset()`.
""" + self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] + self.data_params = data_params self.simPSF = simPSF self.n_bins_lambda = n_bins_lambda self.load_data_on_init = load_data - if self.load_data_on_init: + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: self.load_dataset() - self.process_sed_data() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() else: self.dataset = None self.sed_data = None @@ -104,6 +148,36 @@ def load_dataset(self): os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "train": + if "noisy_stars" not in self.dataset: + logger.warning("Missing 'noisy_stars' in 'train' dataset.") + elif self.dataset_type == "test": + if "stars" not in self.dataset: + logger.warning("Missing 'stars' in 'test' dataset.") + elif self.dataset_type == "inference": + pass + else: + logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = tf.convert_to_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -119,22 +193,46 @@ def load_dataset(self): self.dataset["stars"] = tf.convert_to_tensor( self.dataset["stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - elif "inference" == self.dataset_type: - pass + - def process_sed_data(self): - """Process SED Data. + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. - A method to generate and process SED data. + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
""" + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) - for _sed in self.dataset["SEDs"] + for _sed in sed_data ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) From cfddcdfb4d42a98e274543390db0ee7789c7574c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 020/146] Update unit tests associated to changes in data_handler.py --- .../tests/test_data/data_handler_test.py | 53 +++++++++++++------ .../tests/test_utils/configs_handler_test.py | 19 +++++-- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 3a908274..838db8cb 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -10,6 +10,9 @@ import logging from unittest.mock import patch +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) def test_process_sed_data(data_params, simPSF): # Test processing SED data without initialization @@ -18,11 +21,11 @@ def test_process_sed_data(data_params, simPSF): ) assert data_handler.sed_data is None # SED data should not be processed - # Test processing SED data with initialization +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically data_handler = DataHandler( "training", data_params, simPSF, n_bins_lambda=10, load_data=True ) - assert data_handler.sed_data is not None # SED data should be processed def test_load_train_dataset(tmp_path, data_params, simPSF): @@ -82,7 +85,12 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): ) n_bins_lambda = 10 - data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) + data_handler = DataHandler( + dataset_type="test", + data_params=data_params.test, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False) # Call the load_dataset method data_handler.load_dataset() @@ -93,8 +101,8 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" data_dir = tmp_path / "data" data_dir.mkdir() temp_data_file = data_dir / "train_data.npy" @@ -145,24 +153,37 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): def test_process_sed_data(data_params, simPSF): mock_dataset = { "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + "SEDs": np.array([ + [[0.1, 0.2], [0.3, 0.4]], + [[0.5, 0.6], [0.7, 0.8]] + ]), + # Missing 'noisy_stars' } # Initialize DataHandler instance n_bins_lambda = 4 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - 
assert isinstance(data_handler.sed_data, tf.Tensor)
-    assert data_handler.sed_data.dtype == tf.float32
-    assert data_handler.sed_data.shape == (
-        len(data_handler.dataset["positions"]),
-        n_bins_lambda,
-        len(["feasible_N", "feasible_wv", "SED_norm"]),
+    np.save(temp_data_file, mock_dataset)
+
+    data_params = RecursiveNamespace(
+        data_dir=str(data_dir), file="train_data.npy"
     )
+
+    data_handler = DataHandler(
+        dataset_type="train",
+        data_params=data_params,
+        simPSF=simPSF,
+        n_bins_lambda=10,
+        load_data=False
+    )
+
+    data_handler.load_dataset()
+    data_handler.process_sed_data(mock_dataset["SEDs"])
+
+    with patch("wf_psf.data.data_handler.logger.warning") as mock_warning:
+        data_handler._validate_dataset_structure()
+        mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.")
+
 
 def test_get_obs_positions(mock_data):
     observed_positions = get_obs_positions(mock_data)
diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py
index ebddac66..2f2b6c9c 100644
--- a/src/wf_psf/tests/test_utils/configs_handler_test.py
+++ b/src/wf_psf/tests/test_utils/configs_handler_test.py
@@ -110,7 +110,6 @@ def test_get_run_config(path_to_repo_dir, path_to_tmp_output_dir, path_to_config
     assert type(config_class) is RegisterConfigClass
 
 
-
 def test_data_config_handler_init(
     mock_training_conf, mock_data_read_conf, mocker
 ):
@@ -123,10 +122,18 @@ def test_data_config_handler_init(
         "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance
     )
 
-    # Patch the load_dataset and process_sed_data methods inside DataHandler
-    mocker.patch.object(DataHandler, "load_dataset")
+    # Patch process_sed_data method
     mocker.patch.object(DataHandler, "process_sed_data")
 
+    # Patch validate_and_process_dataset method
+    mocker.patch.object(DataHandler, "validate_and_process_dataset")
+
+    # Patch load_dataset to assign dataset
+    def mock_load_dataset(self):
+        self.dataset = {"SEDs": ["dummy_sed_data"], "positions": ["dummy_positions_data"]}
+
+    mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset)
+
     # Create DataConfigHandler instance
     data_config_handler = DataConfigHandler(
         "/path/to/data_config.yaml",
@@ -145,7 +152,11 @@ def test_data_config_handler_init(
         data_config_handler.test_data.n_bins_lambda
         == mock_training_conf.training.model_params.n_bins_lda
     )
-    assert (data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size)  # Default value
+    assert (
+        data_config_handler.batch_size
+        == mock_training_conf.training.training_hparams.batch_size
+    )
+
 
 def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler):

From 9479146c0bae83b5b8bb594c6810d83b089c4b55 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Sun, 18 May 2025 12:24:21 +0200
Subject: [PATCH 021/146] Change exception handling in DataConfigHandler;
 modify args to DataHandler

---
 src/wf_psf/utils/configs_handler.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py
index 2260abe3..a037069f 100644
--- a/src/wf_psf/utils/configs_handler.py
+++ b/src/wf_psf/utils/configs_handler.py
@@ -129,28 +129,31 @@ class DataConfigHandler:
     def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True):
         try:
             self.data_conf = read_conf(data_conf)
-        except FileNotFoundError as e:
-            logger.exception(e)
-            exit()
-        except TypeError as e:
+        except (FileNotFoundError, TypeError) as e:
             logger.exception(e)
             exit()
self.simPSF = psf_models.simPSF(training_model_params) + + # Extract sub-configs early + train_params = self.data_conf.data.training + test_params = self.data_conf.data.test + self.training_data = DataHandler( dataset_type="training", - data_params=self.data_conf.data, + data_params=train_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) self.test_data = DataHandler( dataset_type="test", - data_params=self.data_conf.data, + data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size From abe53ed4f77a1c1cbe99e0d5bc195e4792947575 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 19 May 2025 10:49:14 +0200 Subject: [PATCH 022/146] Add data and psf_model_imports into inference and sketch out methods --- src/wf_psf/inference/psf_inference.py | 59 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 4f5b39e5..5437ee61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -12,10 +12,61 @@ import glob import logging import numpy as np -from wf_psf.psf_models import psf_models, psf_model_loader +from wf_psf.data.data_handler import DataHandler +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +def prepare_inputs(dataset): + + # Convert dataset to tensorflow Dataset + dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) + + + +def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): + + trained_model_path = model_path + model_subdir = model_dir_name + cycle = cycle + + model_name = training_conf.training.model_params.model_name + id_name = training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + training_conf, + data_conf, + weights_path_pattern, + ) + + +def generate_psfs(psf_model, inputs): + pass + + +def run_pipeline(): + psf_model = get_trained_psf_model( + model_path, + model_dir, + cycle, + training_conf, + data_conf + ) + inputs = prepare_inputs( + + ) + psfs = generate_psfs( + psf_model, + inputs, + batch_size=1, + ) + return psfs -#def prepare_inputs(...): ... -#def generate_psfs(...): ... -#def run_pipeline(...): ... From 3d5ffb0f495994229b6ce0b9b0cdd0652a84935c Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:56:50 +0200 Subject: [PATCH 023/146] add base psf inference --- src/wf_psf/inference/psf_inference.py | 182 ++++++++++++++++++-------- 1 file changed, 127 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5437ee61..6787be61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -1,8 +1,8 @@ """Inference. -A module which provides a set of functions to perform inference -on PSF models. It includes functions to load a trained model, -perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. +A module which provides a PSFInference class to perform inference +with trained PSF models. It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
:Authors: Jennifer Pollack @@ -13,60 +13,132 @@ import logging import numpy as np from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +from typing import Optional -def prepare_inputs(dataset): - - # Convert dataset to tensorflow Dataset - dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) - - - -def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): - - trained_model_path = model_path - model_subdir = model_dir_name - cycle = cycle - - model_name = training_conf.training.model_params.model_name - id_name = training_conf.training.id_name - - weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), - ) - return load_trained_psf_model( - training_conf, - data_conf, - weights_path_pattern, - ) - - -def generate_psfs(psf_model, inputs): - pass - - -def run_pipeline(): - psf_model = get_trained_psf_model( - model_path, - model_dir, - cycle, - training_conf, - data_conf - ) - inputs = prepare_inputs( - - ) - psfs = generate_psfs( - psf_model, - inputs, - batch_size=1, - ) - return psfs +class PSFInference: + """Class to perform inference on PSF models.""" + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + cycle: int, + training_conf_path: str, + data_conf_path: str, + batch_size: Optional[int] = None, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.cycle = cycle + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + + # Set source parameters + self.x_field = None + self.y_field = None + self.seds = None + self.trained_psf_model = None + + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + + # Set the number of labmda bins + self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda + + # Set the batch size + self.batch_size = ( + batch_size + if batch_size is not None + else self.training_conf.training.model_params.batch_size + ) + + # Instantiate the PSF simulator object + self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) + + # Instantiate the data handler + self.data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_conf, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + ) + + # Load the trained PSF model + self.trained_psf_model = self.get_trained_psf_model() + + def get_trained_psf_model(self): + """Get the trained PSF model.""" + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + self.trained_model_path, + self.model_subdir, + (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, + self.data_conf, + weights_path_pattern, + ) + + def set_source_parameters(self, x_field, y_field, seds): + """Set the input source parameters for inferring the PSF. + + Parameters + ---------- + x_field : array-like + X coordinates of the sources in WaveDiff format. + y_field : array-like + Y coordinates of the sources in WaveDiff format. 
+ seds : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + It assumes the standard WaveDiff SED format. + + """ + # Positions array is of shape (n_sources, 2) + self.positions = tf.convert_to_tensor( + np.array([x_field, y_field]).T, dtype=tf.float32 + ) + # Process SED data + self.sed_data = self.data_handler.process_sed_data(seds) + + def get_psfs(self): + """Generate PSFs on the input source parameters.""" + + while counter < n_samples: + # Calculate the batch end element + if counter + batch_size <= n_samples: + end_sample = counter + batch_size + else: + end_sample = n_samples + + # Define the batch positions + batch_pos = pos[counter:end_sample, :] + + inputs = [self.positions, self.sed_data] + poly_psfs = self.trained_psf_model(inputs, training=False) + + return poly_psfs + + +# def run_pipeline(): +# psf_model = get_trained_psf_model( +# model_path, model_dir, cycle, training_conf, data_conf +# ) +# inputs = prepare_inputs() +# psfs = generate_psfs( +# psf_model, +# inputs, +# batch_size=1, +# ) +# return psfs From 2e9cbd201178d65ba834eb8192e8e75aefb82741 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:57:05 +0200 Subject: [PATCH 024/146] add common call interface through PSF models --- src/wf_psf/psf_models/models/psf_model_parametric.py | 2 +- src/wf_psf/psf_models/models/psf_model_semiparametric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py index 1b095852..5c643a3a 100644 --- a/src/wf_psf/psf_models/models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -205,7 +205,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py index c370956c..7b2ff04d 100644 --- a/src/wf_psf/psf_models/models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. 
[1] From positions to Zernike coefficients From b6e3f446ebb4f01a1abbd17cceca1eb905386936 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:35:29 +0200 Subject: [PATCH 025/146] add handling of inference params --- src/wf_psf/inference/psf_inference.py | 101 +++++++++++++++++++++----- 1 file changed, 81 insertions(+), 20 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6787be61..a5ca12a6 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -27,16 +27,15 @@ def __init__( self, trained_model_path: str, model_subdir: str, - cycle: int, training_conf_path: str, data_conf_path: str, - batch_size: Optional[int] = None, + inference_conf_path: str, ): self.trained_model_path = trained_model_path self.model_subdir = model_subdir - self.cycle = cycle self.training_conf_path = training_conf_path self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path # Set source parameters self.x_field = None @@ -44,18 +43,24 @@ def __init__( self.seds = None self.trained_psf_model = None + # Set compute PSF placeholder + self.inferred_psfs = None + # Load the training and data configurations self.training_conf = read_conf(training_conf_path) self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) # Set the number of labmda bins - self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda - + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size - self.batch_size = ( - batch_size - if batch_size is not None - else self.training_conf.training.model_params.batch_size + self.batch_size = self.inference_conf.inference.batch_size + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + + # Overwrite the model parameters with the inference configuration + self.training_conf.training.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf ) # Instantiate the PSF simulator object @@ -73,6 +78,29 @@ def __init__( # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. 
+ + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -110,25 +138,58 @@ def set_source_parameters(self, x_field, y_field, seds): np.array([x_field, y_field]).T, dtype=tf.float32 ) # Process SED data - self.sed_data = self.data_handler.process_sed_data(seds) - - def get_psfs(self): - """Generate PSFs on the input source parameters.""" + self.data_handler.process_sed_data(seds) + self.sed_data = self.data_handler.sed_data + + def compute_psfs(self): + """Compute the PSFs for the input source parameters.""" + + # Check if source parameters are set + if self.positions is None or self.sed_data is None: + raise ValueError( + "Source parameters not set. Call set_source_parameters first." + ) + + # Get the number of samples + n_samples = self.positions.shape[0] + # Initialize counter + counter = 0 + # Initialize PSF array + self.inferred_psfs = np.zeros((n_samples,)) + psf_array = [] while counter < n_samples: # Calculate the batch end element - if counter + batch_size <= n_samples: - end_sample = counter + batch_size + if counter + self.batch_size <= n_samples: + end_sample = counter + self.batch_size else: end_sample = n_samples - # Define the batch positions - batch_pos = pos[counter:end_sample, :] + # Define the batch positions + batch_pos = self.positions[counter:end_sample, :] + batch_seds = self.sed_data[counter:end_sample, :, :] + + # Generate PSFs for the current batch + batch_inputs = [batch_pos, batch_seds] + batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False) + + # Append to the PSF array + psf_array.append(poly_psfs) + + # Update the counter + counter += self.batch_size + + return tf.concat(psf_array, axis=0) + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs.""" + + pass - inputs = [self.positions, self.sed_data] - poly_psfs = self.trained_psf_model(inputs, training=False) + def get_psf(self, index) -> np.ndarray: + """Generate the generated PSF at a specific index.""" - return poly_psfs + pass # def run_pipeline(): From 0414a502cd4525b9d6bfcff55fdad26f0cdef1e2 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 026/146] automatic formatting --- src/wf_psf/data/data_handler.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index c6d51a0d..95eb2e11 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -28,7 +28,7 @@ class DataHandler: DataHandler for WaveDiff PSF modeling. This class manages loading, preprocessing, and TensorFlow conversion of datasets used - for PSF model training, testing, and inference in the WaveDiff framework. + for PSF model training, testing, and inference in the WaveDiff framework. Parameters ---------- @@ -104,7 +104,7 @@ def __init__( dataset : dict or list, optional A pre-loaded dataset to use directly (overrides `load_data`). sed_data : array-like, optional - Pre-loaded SED data to use directly. If not provided but `dataset` is, + Pre-loaded SED data to use directly. 
If not provided but `dataset` is, SEDs are taken from `dataset["SEDs"]`. Raises @@ -114,7 +114,7 @@ def __init__( Notes ----- - - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ @@ -137,7 +137,6 @@ def __init__( self.dataset = None self.sed_data = None - def load_dataset(self): """Load dataset. @@ -154,7 +153,6 @@ def validate_and_process_dataset(self): self._validate_dataset_structure() self._convert_dataset_to_tensorflow() - def _validate_dataset_structure(self): """Validate dataset structure based on dataset_type.""" if self.dataset is None: @@ -174,10 +172,9 @@ def _validate_dataset_structure(self): else: logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") - def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - + self.dataset["positions"] = tf.convert_to_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -216,7 +213,7 @@ def process_sed_data(self, sed_data): Notes ----- - The resulting tensor is stored in `self.sed_data` and has shape + The resulting tensor is stored in `self.sed_data` and has shape `(num_samples, n_bins_lambda, n_components)`, where: - `num_samples` is the number of SEDs, - `n_bins_lambda` is the number of wavelength bins, @@ -290,7 +287,7 @@ def get_obs_positions(data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. - + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the star training and test datasets such as star stamps or masks, based on the provided keys. @@ -320,10 +317,14 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ - key for key, dataset in [(train_key, data.training_data.dataset), (test_key, data.test_data.dataset)] + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] if key not in dataset ] - + if missing_keys: raise KeyError(f"Missing keys in dataset: {missing_keys}") @@ -339,4 +340,3 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) - From 34cd12198d9163754d7941aaa2eac7f1b986ead9 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:43 +0200 Subject: [PATCH 027/146] add first completed class draft --- src/wf_psf/inference/psf_inference.py | 49 +++++++++++++++++++++------ 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index a5ca12a6..80a2c1c1 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -21,7 +21,23 @@ class PSFInference: - """Class to perform inference on PSF models.""" + """Class to perform inference on PSF models. + + + Parameters + ---------- + trained_model_path : str + Path to the directory containing the trained model. + model_subdir : str + Subdirectory name of the trained model. + training_conf_path : str + Path to the training configuration file used to train the model. + data_conf_path : str + Path to the data configuration file. 
+ inference_conf_path : str + Path to the inference configuration file. + + """ def __init__( self, @@ -55,8 +71,11 @@ def __init__( self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size self.batch_size = self.inference_conf.inference.batch_size + assert self.batch_size > 0, "Batch size must be greater than 0." # Set the cycle to use for inference self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions + self.output_dim = self.inference_conf.inference.model_params.output_dim # Overwrite the model parameters with the inference configuration self.training_conf.training.model_params = self.overwrite_model_params( @@ -73,6 +92,7 @@ def __init__( simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, + dataset=None, ) # Load the trained PSF model @@ -155,8 +175,7 @@ def compute_psfs(self): # Initialize counter counter = 0 # Initialize PSF array - self.inferred_psfs = np.zeros((n_samples,)) - psf_array = [] + self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim)) while counter < n_samples: # Calculate the batch end element @@ -174,22 +193,32 @@ def compute_psfs(self): batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False) # Append to the PSF array - psf_array.append(poly_psfs) + self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy() # Update the counter counter += self.batch_size - return tf.concat(psf_array, axis=0) - def get_psfs(self) -> np.ndarray: - """Get all the generated PSFs.""" + """Get all the generated PSFs. - pass + Returns + ------- + np.ndarray + The generated PSFs for the input source parameters. + Shape is (n_samples, output_dim, output_dim). + """ + return self.inferred_psfs def get_psf(self, index) -> np.ndarray: - """Generate the generated PSF at a specific index.""" + """Generate the generated PSF at a specific index. - pass + Returns + ------- + np.ndarray + The generated PSFs for the input source parameters. + Shape is (output_dim, output_dim). + """ + return self.inferred_psfs[index] # def run_pipeline(): From bc795606befa4102f2c194713d2f8e4944291e68 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:55 +0200 Subject: [PATCH 028/146] add inference config file --- config/inference_conf.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 config/inference_conf.yaml diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml new file mode 100644 index 00000000..af67cde6 --- /dev/null +++ b/config/inference_conf.yaml @@ -0,0 +1,28 @@ + +inference: + # Inference batch size + batch_size: 16 + + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Num of wavelength bins to reconstruct polychromatic objects. + n_bins_lda: 20 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. + output_Q: 3 + + # Oversampling rate used for the OPD/WFE PSF model. + oversampling_rate: 3 + + # Dimension of the pixel PSF postage stamp + output_dim: 32 + + # Dimension of the OPD/Wavefront space. 
+ pupil_diameter: 256 + + + + From 95e592cd23c3c0a0c0aac3a81e23c8945041a619 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:46:16 +0200 Subject: [PATCH 029/146] remove unused code --- src/wf_psf/inference/psf_inference.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 80a2c1c1..97db58da 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -219,16 +219,3 @@ def get_psf(self, index) -> np.ndarray: Shape is (output_dim, output_dim). """ return self.inferred_psfs[index] - - -# def run_pipeline(): -# psf_model = get_trained_psf_model( -# model_path, model_dir, cycle, training_conf, data_conf -# ) -# inputs = prepare_inputs() -# psfs = generate_psfs( -# psf_model, -# inputs, -# batch_size=1, -# ) -# return psfs From aa9f46c17410de6838d398ba273ff70b21b3f04e Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:56:51 +0200 Subject: [PATCH 030/146] update params --- config/inference_conf.yaml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index af67cde6..cf4de4ff 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -12,17 +12,8 @@ inference: n_bins_lda: 20 # Downsampling rate to match the oversampled model to the specified telescope's sampling. - output_Q: 3 - - # Oversampling rate used for the OPD/WFE PSF model. - oversampling_rate: 3 + output_Q: 1 # Dimension of the pixel PSF postage stamp output_dim: 32 - - # Dimension of the OPD/Wavefront space. - pupil_diameter: 256 - - - From 5a945e017ceee79505871dd348392eef3c89fe50 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:57:17 +0200 Subject: [PATCH 031/146] update params --- config/inference_conf.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index cf4de4ff..afb0174c 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -9,11 +9,11 @@ inference: # The following parameters will overwrite the `model_params` in the training config file. model_params: # Num of wavelength bins to reconstruct polychromatic objects. - n_bins_lda: 20 + n_bins_lda: 8 # Downsampling rate to match the oversampled model to the specified telescope's sampling. output_Q: 1 # Dimension of the pixel PSF postage stamp - output_dim: 32 + output_dim: 64 From 95c1201422af042254987528f3a3f38710e56696 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:04:40 +0200 Subject: [PATCH 032/146] update inference --- config/inference_conf.yaml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index afb0174c..7e971957 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -2,10 +2,22 @@ inference: # Inference batch size batch_size: 16 - # Cycle to use for inference. Can be: 1, 2, ... cycle: 2 + configs: + # Path to the directory containing the trained model + training_config_path: models/ + + # Subdirectory name of the trained model + model_subdir: models + + # Path to the training configuration file used to train the model + trained_model_path: config/training_config.yaml + + # Path to the data config file (this could contain prior information) + data_conf_path: + # The following parameters will overwrite the `model_params` in the training config file. 
model_params:
     # Num of wavelength bins to reconstruct polychromatic objects.

From f9d7c53403a14da3c42d474ed0d534113a7149f0 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:05:00 +0200
Subject: [PATCH 033/146] reduce arguments and add compute psfs when appropriate

---
 src/wf_psf/inference/psf_inference.py | 48 +++++++++++++--------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 97db58da..e73231cb 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -26,32 +26,31 @@ class PSFInference:
 
     Parameters
     ----------
-    trained_model_path : str
-        Path to the directory containing the trained model.
-    model_subdir : str
-        Subdirectory name of the trained model.
-    training_conf_path : str
-        Path to the training configuration file used to train the model.
-    data_conf_path : str
-        Path to the data configuration file.
     inference_conf_path : str
         Path to the inference configuration file.
 
     """
 
-    def __init__(
-        self,
-        trained_model_path: str,
-        model_subdir: str,
-        training_conf_path: str,
-        data_conf_path: str,
-        inference_conf_path: str,
-    ):
-        self.trained_model_path = trained_model_path
-        self.model_subdir = model_subdir
-        self.training_conf_path = training_conf_path
-        self.data_conf_path = data_conf_path
+    def __init__(self, inference_conf_path: str):
         self.inference_conf_path = inference_conf_path
+        # Load the training and data configurations
+        self.inference_conf = read_conf(inference_conf_path)
+
+        # Set config paths
+        self.config_paths = self.inference_conf.inference.configs.config_paths
+        self.trained_model_path = self.config_paths.trained_model_path
+        self.model_subdir = self.config_paths.model_subdir
+        self.training_config_path = self.config_paths.training_config_path
+        self.data_conf_path = self.config_paths.data_conf_path
+
+        # Load the training and data configurations
+        self.training_conf = read_conf(self.training_conf_path)
+        if self.data_conf_path is not None:
+            # Load the data configuration
+            self.data_conf = read_conf(self.data_conf_path)
+        else:
+            self.data_conf = None
 
         # Set source parameters
         self.x_field = None
@@ -62,11 +61,6 @@ def __init__(
         # Set compute PSF placeholder
         self.inferred_psfs = None
 
-        # Load the training and data configurations
-        self.training_conf = read_conf(training_conf_path)
-        self.data_conf = read_conf(data_conf_path)
-        self.inference_conf = read_conf(inference_conf_path)
-
         # Set the number of labmda bins
         self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda
         # Set the batch size
@@ -207,6 +201,8 @@ def get_psfs(self) -> np.ndarray:
             The generated PSFs for the input source parameters.
             Shape is (n_samples, output_dim, output_dim).
         """
+        if self.inferred_psfs is None:
+            self.compute_psfs()
         return self.inferred_psfs
 
     def get_psf(self, index) -> np.ndarray:
         """Generate the generated PSF at a specific index.
 
         Returns
         -------
         np.ndarray
             The generated PSFs for the input source parameters.
             Shape is (output_dim, output_dim).
""" + if self.inferred_psfs is None: + self.compute_psfs() return self.inferred_psfs[index] From 6c84c0b770ec2d6ac1c31abd65e6d96a9278a971 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:08:49 +0200 Subject: [PATCH 034/146] add config handler class --- src/wf_psf/inference/psf_inference.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e73231cb..d0f2b930 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -20,6 +20,59 @@ from typing import Optional +class InferenceConfigHandler: + ids = ("inference_conf",) + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + training_conf_path: str, + data_conf_path: str, + inference_conf_path: str, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path + + # Overwrite the model parameters with the inference configuration + self.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf + ) + + def read_configurations(self): + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) + + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. + + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + + class PSFInference: """Class to perform inference on PSF models. 
From 52a6676cf52d24c18d754cbcb6596ffcc9e608ad Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:48:47 +0200 Subject: [PATCH 035/146] set up inference config handler and simplify PSFInferenc init --- src/wf_psf/inference/psf_inference.py | 103 ++++++++++++++------------ 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index d0f2b930..57f84080 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -23,30 +23,43 @@ class InferenceConfigHandler: ids = ("inference_conf",) - def __init__( - self, - trained_model_path: str, - model_subdir: str, - training_conf_path: str, - data_conf_path: str, - inference_conf_path: str, - ): - self.trained_model_path = trained_model_path - self.model_subdir = model_subdir - self.training_conf_path = training_conf_path - self.data_conf_path = data_conf_path + def __init__(self, inference_conf_path: str): self.inference_conf_path = inference_conf_path + # Load the inference configuration + self.read_configurations() + # Overwrite the model parameters with the inference configuration self.model_params = self.overwrite_model_params( self.training_conf, self.inference_conf ) def read_configurations(self): + """Read the configuration files.""" + # Load the inference configuration + self.inference_conf = read_conf(self.inference_conf_path) + # Set config paths + self.set_config_paths() # Load the training and data configurations - self.training_conf = read_conf(training_conf_path) - self.data_conf = read_conf(data_conf_path) - self.inference_conf = read_conf(inference_conf_path) + self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: + # Load the data configuration + self.data_conf = read_conf(self.data_conf_path) + else: + self.data_conf = None + + def set_config_paths(self): + """Extract and set the configuration paths.""" + # Set config paths + self.config_paths = self.inference_conf.inference.configs.config_paths + self.trained_model_path = self.config_paths.trained_model_path + self.model_subdir = self.config_paths.model_subdir + self.training_config_path = self.config_paths.training_config_path + self.data_conf_path = self.config_paths.data_conf_path + + def get_configs(self): + """Get the configurations.""" + return (self.inference_conf, self.training_conf, self.data_conf) @staticmethod def overwrite_model_params(training_conf=None, inference_conf=None): @@ -86,48 +99,25 @@ class PSFInference: def __init__(self, inference_conf_path: str): - self.inference_conf_path = inference_conf_path - # Load the training and data configurations - self.inference_conf = read_conf(inference_conf_path) - - # Set config paths - self.config_paths = self.inference_conf.inference.configs.config_paths - self.trained_model_path = self.config_paths.trained_model_path - self.model_subdir = self.config_paths.model_subdir - self.training_config_path = self.config_paths.training_config_path - self.data_conf_path = self.config_paths.data_conf_path + self.inference_config_handler = InferenceConfigHandler( + inference_conf_path=inference_conf_path + ) - # Load the training and data configurations - self.training_conf = read_conf(self.training_conf_path) - if self.data_conf_path is not None: - # Load the data configuration - self.data_conf = read_conf(self.data_conf_path) - else: - self.data_conf = None + self.inference_conf, self.training_conf, self.data_conf = ( + 
self.inference_config_handler.get_configs() + ) - # Set source parameters + # Init source parameters self.x_field = None self.y_field = None self.seds = None self.trained_psf_model = None - # Set compute PSF placeholder + # Init compute PSF placeholder self.inferred_psfs = None - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - # Overwrite the model parameters with the inference configuration - self.training_conf.training.model_params = self.overwrite_model_params( - self.training_conf, self.inference_conf - ) + # Load inference parameters + self.load_inference_params() # Instantiate the PSF simulator object self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) @@ -145,6 +135,18 @@ def __init__(self, inference_conf_path: str): # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + def load_inference_params(self): + """Load the inference parameters from the configuration file.""" + # Set the number of labmda bins + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size + self.batch_size = self.inference_conf.inference.batch_size + assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions + self.output_dim = self.inference_conf.inference.model_params.output_dim + @staticmethod def overwrite_model_params(training_conf=None, inference_conf=None): """Overwrite model_params of the training_conf with the inference_conf. @@ -188,6 +190,11 @@ def get_trained_psf_model(self): def set_source_parameters(self, x_field, y_field, seds): """Set the input source parameters for inferring the PSF. + Note + ---- + The input source parameters are expected to be in the WaveDiff format. See the simulated data + format for more details. + Parameters ---------- x_field : array-like From 0f7eea65d127118867502955cbc02c21cf4732a5 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:49:47 +0200 Subject: [PATCH 036/146] remove unused imports --- src/wf_psf/inference/psf_inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 57f84080..797a410f 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -4,20 +4,17 @@ with trained PSF models. It is able to load a trained model, perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
-:Authors: Jennifer Pollack +:Authors: Jennifer Pollack , Tobias Liaudat """ import os -import glob -import logging import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf -from typing import Optional class InferenceConfigHandler: From 0a0fc2d7af8da11de10b3fa2a19b0628c566cfae Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:50:52 +0200 Subject: [PATCH 037/146] update inference --- config/inference_conf.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 7e971957..0c846fca 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -5,6 +5,7 @@ inference: # Cycle to use for inference. Can be: 1, 2, ... cycle: 2 + # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model training_config_path: models/ From b1fdf29d65da28a53cad7ecd44cb169372d21149 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 May 2025 11:43:09 +0200 Subject: [PATCH 038/146] Add single-space lines to improve readability; Remove duplicated static method --- src/wf_psf/inference/psf_inference.py | 34 ++++++++------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 797a410f..f8496453 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -35,10 +35,13 @@ def read_configurations(self): """Read the configuration files.""" # Load the inference configuration self.inference_conf = read_conf(self.inference_conf_path) + # Set config paths self.set_config_paths() + # Load the training and data configurations self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: # Load the data configuration self.data_conf = read_conf(self.data_conf_path) @@ -136,37 +139,18 @@ def load_inference_params(self): """Load the inference parameters from the configuration file.""" # Set the number of labmda bins self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size self.batch_size = self.inference_conf.inference.batch_size assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions self.output_dim = self.inference_conf.inference.model_params.output_dim - - @staticmethod - def overwrite_model_params(training_conf=None, inference_conf=None): - """Overwrite model_params of the training_conf with the inference_conf. - - Parameters - ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. 
- - """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params - if model_params is not None and inf_model_params is not None: - for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute - if hasattr(model_params, key): - # Set the attribute of model_params to the new value - setattr(model_params, key, value) - - return model_params - + + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -223,8 +207,10 @@ def compute_psfs(self): # Get the number of samples n_samples = self.positions.shape[0] + # Initialize counter counter = 0 + # Initialize PSF array self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim)) From 8aedf77a32e46c4babf1e7a3e54dc693d3be6cc9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 26 May 2025 16:19:40 +0200 Subject: [PATCH 039/146] Add additional PSFInference class attributes; update set_source_parameters to use class attributes; assign variable names in get_trained_psf_model; Update class and __init__ docstrings --- src/wf_psf/inference/psf_inference.py | 83 ++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index f8496453..552abd02 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -87,17 +87,67 @@ def overwrite_model_params(training_conf=None, inference_conf=None): class PSFInference: - """Class to perform inference on PSF models. + """ + Perform PSF inference using a pre-trained WaveDiff model. + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. Parameters ---------- - inference_conf_path : str - Path to the inference configuration file. - + inference_conf_path : str, optional + Path to the inference configuration YAML file. This file should define + paths and parameters for the inference, training, and data configurations. + x_field : array-like, optional + Array of x field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + y_field : array-like, optional + Array of y field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + seds : array-like, optional + Spectral energy distributions (SEDs) for the sources being modeled. These + will be used as part of the input to the PSF simulator. + + Attributes + ---------- + inference_config_handler : InferenceConfigHandler + Handler object to load and parse inference, training, and data configs. + inference_conf : dict + Dictionary containing inference configuration settings. + training_conf : dict + Dictionary containing training configuration settings. + data_conf : dict + Dictionary containing data configuration settings. + x_field : array-like + Input x coordinates after transformation (if applicable). + y_field : array-like + Input y coordinates after transformation (if applicable). + seds : array-like + Input spectral energy distributions. + trained_psf_model : keras.Model + Loaded PSF model used for prediction. + inferred_psfs : array-like or None + Array of inferred PSF images, populated after inference is performed. + simPSF : psf_models.simPSF + PSF simulator instance initialized with training model parameters. 
+    data_handler : DataHandler
+        Data handler configured for inference, used to prepare inputs to the model.
+    n_bins_lambda : int
+        Number of spectral bins used for PSF simulation (loaded from config).
+
+    Methods
+    -------
+    load_inference_params()
+        Load parameters required for inference, including spectral binning.
+    get_trained_psf_model()
+        Load and return the trained Keras model for PSF inference.
+    compute_psfs()
+        Run the model on the input data and generate predicted PSFs.
     """
 
-    def __init__(self, inference_conf_path: str):
+
+    def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None):
 
         self.inference_config_handler = InferenceConfigHandler(
             inference_conf_path=inference_conf_path
@@ -108,9 +158,9 @@ def __init__(self, inference_conf_path: str):
         )
 
         # Init source parameters
-        self.x_field = None
-        self.y_field = None
-        self.seds = None
+        self.x_field = x_field
+        self.y_field = y_field
+        self.seds = seds
         self.trained_psf_model = None
 
         # Init compute PSF placeholder
@@ -149,18 +199,21 @@ def load_inference_params(self):
         # Get output psf dimensions
         self.output_dim = self.inference_conf.inference.model_params.output_dim
 
-
+
     def get_trained_psf_model(self):
         """Get the trained PSF model."""
 
+        # Load the trained PSF model
+        model_path = self.inference_config_handler.trained_model_path
+        model_dir_name = self.inference_config_handler.model_subdir
 
         model_name = self.training_conf.training.model_params.model_name
         id_name = self.training_conf.training.id_name
 
         weights_path_pattern = os.path.join(
-            self.trained_model_path,
-            self.model_subdir,
-            (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
+            model_path,
+            model_dir_name,
+            (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
         )
         return load_trained_psf_model(
             self.training_conf,
@@ -168,7 +221,7 @@ def get_trained_psf_model(self):
             weights_path_pattern,
         )
 
-    def set_source_parameters(self, x_field, y_field, seds):
+    def set_source_parameters(self):
         """Set the input source parameters for inferring the PSF.
Note @@ -190,10 +243,10 @@ def set_source_parameters(self, x_field, y_field, seds): """ # Positions array is of shape (n_sources, 2) self.positions = tf.convert_to_tensor( - np.array([x_field, y_field]).T, dtype=tf.float32 + np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) # Process SED data - self.data_handler.process_sed_data(seds) + self.data_handler.process_sed_data(self.seds) self.sed_data = self.data_handler.sed_data def compute_psfs(self): From 0968e0fb55b3f8e50afb1cc93170bee6a18c0bfe Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 27 May 2025 18:09:04 +0100 Subject: [PATCH 040/146] Add checks to convert to np.ndarray and expand dimensions if needed --- src/wf_psf/data/centroids.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 01ecea4e..c95b303f 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -247,6 +247,19 @@ class CentroidEstimator: def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None): """Initialize class attributes.""" + + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: From dc464fb097bc72d9de8bdd1dbb87ea9a4a9fc31d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 5 Jun 2025 13:53:34 +0100 Subject: [PATCH 041/146] Update pyproject.toml with numpy dependency limits - sdc-uk --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 24160d37..f6c6a922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' dependencies = [ - "numpy>=1.26.4,<2.0", + "numpy>=1.18,<1.24", "scipy", - "tensorflow==2.11.0", + # "tensorflow==2.11.0", "tensorflow-addons", "tensorflow-estimator", "zernike", From 4d3bdbe58450c3f1a6cc33e370d05d16da7eaa3e Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 5 Jun 2025 13:56:31 +0100 Subject: [PATCH 042/146] Revert "Update pyproject.toml with numpy dependency limits - sdc-uk" This reverts commit 076a5f27d54287cb43dacc5f9d4a574488b5999c. --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6c6a922..24160d37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' 
dependencies = [ - "numpy>=1.18,<1.24", + "numpy>=1.26.4,<2.0", "scipy", - # "tensorflow==2.11.0", + "tensorflow==2.11.0", "tensorflow-addons", "tensorflow-estimator", "zernike", From 75cdaac11d4bd05a962179f9dbba16d93649c1cc Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 6 Jun 2025 12:53:26 +0200 Subject: [PATCH 043/146] Correct name of psf_inference_test module to follow repo naming convention --- .../{test_psf_inference.py => psf_inference_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/wf_psf/tests/test_inference/{test_psf_inference.py => psf_inference_test.py} (100%) diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/psf_inference_test.py similarity index 100% rename from src/wf_psf/tests/test_inference/test_psf_inference.py rename to src/wf_psf/tests/test_inference/psf_inference_test.py From 1171b7bf69ce39bf369b9ae6df155d4407a615e4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:22:14 +0200 Subject: [PATCH 044/146] Correct config subkey names for defining trained_model_path and trained_model_config_path --- config/inference_conf.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 0c846fca..c9d29cb8 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -8,16 +8,16 @@ inference: # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model - training_config_path: models/ + trained_model_path: /path/to/trained/model/ - # Subdirectory name of the trained model - model_subdir: models + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model - # Path to the training configuration file used to train the model - trained_model_path: config/training_config.yaml + # Relative Path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml # Path to the data config file (this could contain prior information) - data_conf_path: + data_config_path: # The following parameters will overwrite the `model_params` in the training config file. 
model_params:

From c750904a64d0032e0e6738d39bd69c246babbdf4 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Sun, 8 Jun 2025 18:24:26 +0200
Subject: [PATCH 045/146] Refactor psf_inference adding PSFInferenceEngine to
 separate concerns, enabling isolated testing; implement lazy loaders for
 config handling, model loading, and inference

---
 src/wf_psf/inference/psf_inference.py | 426 +++++++++++++-------------
 1 file changed, 206 insertions(+), 220 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 552abd02..9012e3fc 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -9,6 +9,7 @@
 """
 
 import os
+from pathlib import Path
 import numpy as np
 from wf_psf.data.data_handler import DataHandler
 from wf_psf.utils.read_config import read_conf
@@ -20,71 +21,60 @@
 class InferenceConfigHandler:
     ids = ("inference_conf",)
 
-    def __init__(self, inference_conf_path: str):
-        self.inference_conf_path = inference_conf_path
+    def __init__(self, inference_config_path: str):
+        self.inference_config_path = inference_config_path
+        self.inference_config = None
+        self.training_config = None
+        self.data_config = None
 
-        # Load the inference configuration
-        self.read_configurations()
-
-        # Overwrite the model parameters with the inference configuration
-        self.model_params = self.overwrite_model_params(
-            self.training_conf, self.inference_conf
-        )
-
-    def read_configurations(self):
-        """Read the configuration files."""
-        # Load the inference configuration
-        self.inference_conf = read_conf(self.inference_conf_path)
-
-        # Set config paths
+    def load_configs(self):
+        """Load configuration files based on the inference config."""
+        self.inference_config = read_conf(self.inference_config_path)
         self.set_config_paths()
-
-        # Load the training and data configurations
-        self.training_conf = read_conf(self.training_conf_path)
-
-        if self.data_conf_path is not None:
+        self.training_config = read_conf(self.trained_model_config_path)
+
+        if self.data_config_path is not None:
             # Load the data configuration
-            self.data_conf = read_conf(self.data_conf_path)
-        else:
-            self.data_conf = None
+            self.data_config = read_conf(self.data_config_path)
+
 
     def set_config_paths(self):
         """Extract and set the configuration paths."""
         # Set config paths
-        self.config_paths = self.inference_conf.inference.configs.config_paths
-        self.trained_model_path = self.config_paths.trained_model_path
-        self.model_subdir = self.config_paths.model_subdir
-        self.training_config_path = self.config_paths.training_config_path
-        self.data_conf_path = self.config_paths.data_conf_path
+        config_paths = self.inference_config.inference.configs
+
+        self.trained_model_path = Path(config_paths.trained_model_path)
+        self.model_subdir = config_paths.model_subdir
+        self.trained_model_config_path = self.trained_model_path / config_paths.trained_model_config_path
+        self.data_config_path = config_paths.data_config_path
 
-    def get_configs(self):
-        """Get the configurations."""
-        return (self.inference_conf, self.training_conf, self.data_conf)
 
     @staticmethod
-    def overwrite_model_params(training_conf=None, inference_conf=None):
-        """Overwrite model_params of the training_conf with the inference_conf.
+    def overwrite_model_params(training_config=None, inference_config=None):
+        """
+        Overwrite training model_params with values from inference_config if available.
Parameters ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. - + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params - if model_params is not None and inf_model_params is not None: + if model_params and inf_model_params: for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute if hasattr(model_params, key): - # Set the attribute of model_params to the new value setattr(model_params, key, value) - return model_params - + class PSFInference: """ @@ -96,220 +86,216 @@ class PSFInference: Parameters ---------- - inference_conf_path : str, optional - Path to the inference configuration YAML file. This file should define - paths and parameters for the inference, training, and data configurations. + inference_config_path : str + Path to the inference configuration YAML file. x_field : array-like, optional - Array of x field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + x coordinates in SHE convention. y_field : array-like, optional - Array of y field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + y coordinates in SHE convention. seds : array-like, optional - Spectral energy distributions (SEDs) for the sources being modeled. These - will be used as part of the input to the PSF simulator. - - Attributes - ---------- - inference_config_handler : InferenceConfigHandler - Handler object to load and parse inference, training, and data configs. - inference_conf : dict - Dictionary containing inference configuration settings. - training_conf : dict - Dictionary containing training configuration settings. - data_conf : dict - Dictionary containing data configuration settings. - x_field : array-like - Input x coordinates after transformation (if applicable). - y_field : array-like - Input y coordinates after transformation (if applicable). - seds : array-like - Input spectral energy distributions. - trained_psf_model : keras.Model - Loaded PSF model used for prediction. - inferred_psfs : array-like or None - Array of inferred PSF images, populated after inference is performed. - simPSF : psf_models.simPSF - PSF simulator instance initialized with training model parameters. - data_handler : DataHandler - Data handler configured for inference, used to prepare inputs to the model. - n_bins_lambda : int - Number of spectral bins used for PSF simulation (loaded from config). - - Methods - ------- - load_inference_params() - Load parameters required for inference, including spectral binning. - get_trained_psf_model() - Load and return the trained Keras model for PSF inference. - run_inference() - Run the model on the input data and generate predicted PSFs. + Spectral energy distributions (SEDs). 
""" + def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None): - def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None): + self.inference_config_path = inference_config_path - self.inference_config_handler = InferenceConfigHandler( - inference_conf_path=inference_conf_path - ) - - self.inference_conf, self.training_conf, self.data_conf = ( - self.inference_config_handler.get_configs() - ) - - # Init source parameters + # Inputs for the model self.x_field = x_field self.y_field = y_field self.seds = seds - self.trained_psf_model = None - - # Init compute PSF placeholder - self.inferred_psfs = None - - # Load inference parameters - self.load_inference_params() - - # Instantiate the PSF simulator object - self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) - - # Instantiate the data handler - self.data_handler = DataHandler( - dataset_type="inference", - data_params=self.data_conf, - simPSF=self.simPSF, - n_bins_lambda=self.n_bins_lambda, - load_data=False, - dataset=None, + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """Prepare the configuration for inference.""" + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config ) - # Load the trained PSF model - self.trained_psf_model = self.get_trained_psf_model() + @property + def inference_config(self): + return self.config_handler.inference_config + + @property + def training_config(self): + return self.config_handler.training_config + + @property + def data_config(self): + return self.config_handler.data_config + + @property + def simPSF(self): + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.model_params) + return self._simPSF + + @property + def data_handler(self): + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=None, + ) + return self._data_handler - def load_inference_params(self): - """Load the inference parameters from the configuration file.""" - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + @property + def trained_psf_model(self): + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." 
+ def load_inference_model(self): + # Prepare the configuration for inference + self.prepare_configs() - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - - def get_trained_psf_model(self): - """Get the trained PSF model.""" - - # Load the trained PSF model - model_path = self.inference_config_handler.trained_model_path - model_dir_name = self.inference_config_handler - model_name = self.training_conf.training.model_params.model_name - id_name = self.training_conf.training.id_name + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name weights_path_pattern = os.path.join( model_path, - model_dir_name, - (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" ) + + # Load the trained PSF model return load_trained_psf_model( - self.training_conf, - self.data_conf, + self.training_config, + self.data_config, weights_path_pattern, ) - def set_source_parameters(self): - """Set the input source parameters for inferring the PSF. - - Note - ---- - The input source parameters are expected to be in the WaveDiff format. See the simulated data - format for more details. - - Parameters - ---------- - x_field : array-like - X coordinates of the sources in WaveDiff format. - y_field : array-like - Y coordinates of the sources in WaveDiff format. - seds : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. - It assumes the standard WaveDiff SED format. - - """ - # Positions array is of shape (n_sources, 2) - self.positions = tf.convert_to_tensor( + @property + def n_bins_lambda(self): + if self._n_bins_lambda is None: + self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + return self._n_bins_lambda + + @property + def batch_size(self): + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """Preprocess and return tensors for positions and SEDs.""" + positions = tf.convert_to_tensor( np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) - # Process SED data self.data_handler.process_sed_data(self.seds) - self.sed_data = self.data_handler.sed_data + sed_data = self.data_handler.sed_data + return positions, sed_data - def compute_psfs(self): - """Compute the PSFs for the input source parameters.""" + def run_inference(self): + """Run PSF inference and return the full PSF array.""" + positions, sed_data = self._prepare_positions_and_seds() - # Check if source parameters are set - if self.positions is None or self.sed_data is None: - raise ValueError( - "Source parameters not set. Call set_source_parameters first." 
-            )
+        self.engine = PSFInferenceEngine(
+            trained_model=self.trained_psf_model,
+            batch_size=self.batch_size,
+            output_dim=self.output_dim,
+        )
+        return self.engine.compute_psfs(positions, sed_data)
+
+    def _ensure_psf_inference_completed(self):
+        if self.engine is None or self.engine.inferred_psfs is None:
+            self.run_inference()
+
+    def get_psfs(self):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psfs()
+
+    def get_psf(self, index):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psf(index)

-        # Get the number of samples
-        n_samples = self.positions.shape[0]

+class PSFInferenceEngine:
+    def __init__(self, trained_model, batch_size: int, output_dim: int):
+        self.trained_model = trained_model
+        self.batch_size = batch_size
+        self.output_dim = output_dim
+        self._inferred_psfs = None
+
+    @property
+    def inferred_psfs(self) -> np.ndarray:
+        """Access the cached inferred PSFs, if available."""
+        return self._inferred_psfs
+
+    def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray:
+        """Compute and cache PSFs for the input source parameters."""
+        n_samples = positions.shape[0]
+        self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32)

         # Initialize counter
         counter = 0
-
-        # Initialize PSF array
-        self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim))
-
         while counter < n_samples:
             # Calculate the batch end element
-            if counter + self.batch_size <= n_samples:
-                end_sample = counter + self.batch_size
-            else:
-                end_sample = n_samples
+            end = min(counter + self.batch_size, n_samples)

             # Define the batch positions
-            batch_pos = self.positions[counter:end_sample, :]
-            batch_seds = self.sed_data[counter:end_sample, :, :]
-
-            # Generate PSFs for the current batch
+            batch_pos = positions[counter:end, :]
+            batch_seds = sed_data[counter:end, :, :]
             batch_inputs = [batch_pos, batch_seds]
-            batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False)
-
-            # Append to the PSF array
-            self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy()
+
+            # Generate PSFs for the current batch
+            batch_psfs = self.trained_model(batch_inputs, training=False)
+            self.inferred_psfs[counter:end, :, :] = batch_psfs.numpy()

             # Update the counter
-            counter += self.batch_size
+            counter = end
+
+        return self._inferred_psfs

     def get_psfs(self) -> np.ndarray:
-        """Get all the generated PSFs.
+        """Get all the generated PSFs."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs
+
+    def get_psf(self, index: int) -> np.ndarray:
+        """Get the PSF at a specific index."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs[index]
+

-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (n_samples, output_dim, output_dim).
-        """
-        if self.inferred_psfs is None:
-            self.compute_psfs()
-        return self.inferred_psfs
-
-    def get_psf(self, index) -> np.ndarray:
-        """Generate the generated PSF at a specific index.
-
-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (output_dim, output_dim).
- """ - if self.inferred_psfs is None: - self.compute_psfs() - return self.inferred_psfs[index] From aeeefd253ad3150b32b7ed90266e095f8ab9c819 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:26:26 +0200 Subject: [PATCH 046/146] Add unit tests for psf_inference --- .../test_inference/psf_inference_test.py | 255 ++++++++++++++++++ 1 file changed, 255 insertions(+) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index e69de29b..5709ed6d 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 +1,255 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. + +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine +) + +from wf_psf.utils.read_config import RecursiveNamespace + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32 + ) + ) + ) + return training_config + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path='/path/to/trained/model', + model_subdir='psf_model', + trained_model_config_path='config/training_config.yaml', + data_config_path=None + ), + model_params=RecursiveNamespace( + output_Q=1, + output_dim=64 + ) + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + 
"""Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params( + training_config, inference_config + ) + + # Assert that the model_params were overwritten correctly + assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" + assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" + + assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference('/dummy/path.yaml') + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called['load'] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + def overwrite_model_params(self, *args): pass + + monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + + inference.prepare_configs() + + assert 'load' in called # Confirm lazy load happened + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + ) + assert inference.batch_size == 4 + + +@patch.object(PSFInference, 'prepare_configs') +@patch('wf_psf.inference.psf_inference.load_trained_psf_model') +def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config): + + data_config = MagicMock() + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = data_config + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + 
f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" + ) + + # Assert calls to the mocked methods + mock_prepare_configs.assert_called_once() + mock_load_trained_psf_model.assert_called_once_with( + mock_config_handler.training_config, + mock_config_handler.data_config, + weights_path_pattern + ) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock_compute_psfs.call_count == 1 From e791771f90c4ecb474c41c25d7c3006a342b9172 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 12 Jun 2025 13:22:51 +0200 Subject: [PATCH 047/146] Bugfix: Ensure updated training_config.training.model_params are passed to simPSF Move call to prepare_configs() into run_inference() to ensure model_params are overwritten before preparing SEDs. 
---
 src/wf_psf/inference/psf_inference.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 9012e3fc..01e7cf1a 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -147,7 +147,7 @@ def data_config(self):
     @property
     def simPSF(self):
         if self._simPSF is None:
-            self._simPSF = psf_models.simPSF(self.model_params)
+            self._simPSF = psf_models.simPSF(self.training_config.training.model_params)
         return self._simPSF

     @property
@@ -171,9 +171,7 @@ def trained_psf_model(self):
         return self._trained_psf_model

     def load_inference_model(self):
-        # Prepare the configuration for inference
-        self.prepare_configs()
-
+        """Load the trained PSF model based on the inference configuration."""
         model_path = self.config_handler.trained_model_path
         model_dir = self.config_handler.model_subdir
         model_name = self.training_config.training.model_params.model_name
@@ -228,6 +226,10 @@ def _prepare_positions_and_seds(self):

     def run_inference(self):
         """Run PSF inference and return the full PSF array."""
+        # Prepare the configuration for inference
+        self.prepare_configs()
+
+        # Prepare positions and SEDs for inference
         positions, sed_data = self._prepare_positions_and_seds()

         self.engine = PSFInferenceEngine(

From 447f55442cd40b4cf7de2e6cc023d7c4a7b37ca8 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Thu, 12 Jun 2025 13:25:28 +0200
Subject: [PATCH 048/146] test(simPSF): add unit test to verify updated model_params are passed

Also move the prepare_configs() assertion from test_load_inference_model to
the test_run_inference unit test.
---
 .../test_inference/psf_inference_test.py | 47 ++++++++++++++++---
 1 file changed, 40 insertions(+), 7 deletions(-)

diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py
index 5709ed6d..4add460b 100644
--- a/src/wf_psf/tests/test_inference/psf_inference_test.py
+++ b/src/wf_psf/tests/test_inference/psf_inference_test.py
@@ -179,17 +179,15 @@ def test_batch_size_positive():
     assert inference.batch_size == 4


-@patch.object(PSFInference, 'prepare_configs')
 @patch('wf_psf.inference.psf_inference.load_trained_psf_model')
-def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config):
+def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config):

-    data_config = MagicMock()
     mock_config_handler = MagicMock(spec=InferenceConfigHandler)
     mock_config_handler.trained_model_path = "mock/path/to/model"
     mock_config_handler.training_config = mock_training_config
     mock_config_handler.inference_config = mock_inference_config
     mock_config_handler.model_subdir = "psf_model"
-    mock_config_handler.data_config = data_config
+    mock_config_handler.data_config = MagicMock()

     psf_inf = PSFInference("dummy_path.yaml")
     psf_inf._config_handler = mock_config_handler
@@ -203,17 +201,16 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config):
     )

     # Assert calls to the mocked methods
-    mock_prepare_configs.assert_called_once()
     mock_load_trained_psf_model.assert_called_once_with(
         mock_config_handler.training_config,
         mock_config_handler.data_config,
         weights_path_pattern
     )

-
+@patch.object(PSFInference, 'prepare_configs')
 @patch.object(PSFInference, '_prepare_positions_and_seds')
 @patch.object(PSFInferenceEngine, 'compute_psfs')
-def test_run_inference(mock_compute_psfs,
mock_prepare_positions_and_seds, psf_test_setup): +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -228,6 +225,42 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_t assert psfs.shape == expected_psfs.shape mock_prepare_positions_and_seds.assert_called_once() mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + +@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q @patch.object(PSFInference, '_prepare_positions_and_seds') From 89da2e06c8d4c38d7d66ebfcc537109cc59e7c14 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 12 Jun 2025 16:34:20 +0100 Subject: [PATCH 049/146] Bug: replace self.data_conf with self.data_config --- src/wf_psf/inference/psf_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 01e7cf1a..6f798245 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -36,7 +36,7 @@ def load_configs(self): if self.data_config_path is not None: # Load the data configuration - self.data_conf = read_conf(self.data_config_path) + self.data_config = read_conf(self.data_config_path) def set_config_paths(self): From 2cb7515a6b6c282e34b6644f2f7da35dc74cee6c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 19:59:51 +0200 Subject: [PATCH 050/146] Change logger.warnings to ValueErrors for missing fields in datasets & remove redundant check --- src/wf_psf/data/data_handler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 95eb2e11..e2575d74 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -161,16 +161,16 @@ def _validate_dataset_structure(self): if "positions" not in self.dataset: raise ValueError("Dataset missing required field: 'positions'") - if 
self.dataset_type == "train": + if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - logger.warning("Missing 'noisy_stars' in 'train' dataset.") + raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") elif self.dataset_type == "test": if "stars" not in self.dataset: - logger.warning("Missing 'stars' in 'test' dataset.") + raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") elif self.dataset_type == "inference": pass else: - logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" @@ -179,12 +179,10 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"], dtype=tf.float32 ) if self.dataset_type == "training": - if "noisy_stars" in self.dataset: self.dataset["noisy_stars"] = tf.convert_to_tensor( self.dataset["noisy_stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "test": if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( From 355f0301308eca07ebe26d52b5360401507a77f4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 20:00:21 +0200 Subject: [PATCH 051/146] Update unit tests with changes to data_handler.py --- .../tests/test_data/data_handler_test.py | 80 ++++--------------- 1 file changed, 16 insertions(+), 64 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 838db8cb..91226d66 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -14,21 +14,16 @@ def mock_sed(): # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like return np.linspace(0.1, 1.0, 50) -def test_process_sed_data(data_params, simPSF): - # Test processing SED data without initialization - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=False - ) - assert data_handler.sed_data is None # SED data should not be processed - def test_process_sed_data_auto_load(data_params, simPSF): # load_data=True → dataset is used and SEDs processed automatically data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=True + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda -def test_load_train_dataset(tmp_path, data_params, simPSF): +def test_load_train_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -45,9 +40,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) @@ -63,7 +56,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_test_dataset(tmp_path, data_params, simPSF): +def 
test_load_test_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -80,14 +73,12 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) - + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") + n_bins_lambda = 10 data_handler = DataHandler( dataset_type="test", - data_params=data_params.test, + data_params=data_params, simPSF=simPSF, n_bins_lambda=n_bins_lambda, load_data=False) @@ -114,18 +105,16 @@ def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises(ValueError, match="Missing required field 'noisy_stars' in training dataset."): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") + data_handler.validate_and_process_dataset() -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): +def test_load_test_dataset_missing_stars(tmp_path, simPSF): """Test that a warning is raised if 'stars' is missing in test data.""" data_dir = tmp_path / "data" data_dir.mkdir() @@ -138,51 +127,14 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises(ValueError, match="Missing required field 'stars' in test dataset."): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "SEDs": np.array([ - [[0.1, 0.2], [0.3, 0.4]], - [[0.5, 0.6], [0.7, 0.8]] - ]), - # Missing 'noisy_stars' - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace( - data_dir=str(data_dir), file="train_data.npy" - ) - - data_handler = DataHandler( - dataset_type="train", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=10, - load_data=False - ) - - data_handler.load_dataset() - data_handler.process_sed_data(mock_dataset["SEDs"]) - - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: - data_handler._validate_dataset_structure() - mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + data_handler.validate_and_process_dataset() def test_get_obs_positions(mock_data): From be7ddb6ee3cfddb91664cd5558f3610aee5603d7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 
Subject: [PATCH 052/146] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/data_handler.py | 121 ++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index e2575d74..ef267580 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,13 +17,17 @@ import wf_psf.utils.utils as utils import tensorflow as tf from fractions import Fraction +<<<<<<< HEAD from typing import Optional, Union +======= +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) import logging logger = logging.getLogger(__name__) class DataHandler: +<<<<<<< HEAD """ DataHandler for WaveDiff PSF modeling. @@ -88,10 +92,54 @@ def __init__( and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded from disk using `data_params`, and SEDs are extracted and processed automatically. +======= + """Data Handler. + + This class manages loading and processing of training and testing data for use during PSF model training and validation. + It provides methods to access and preprocess the data. + + Parameters + ---------- + dataset_type: str + A string indicating type of data ("train" or "test"). + data_params: Recursive Namespace object + Recursive Namespace object containing training data parameters + simPSF: PSFSimulator + An instance of the PSFSimulator class for simulating a PSF. + n_bins_lambda: int + The number of bins in wavelength. + load_data: bool, optional + A flag used to control data loading steps. If True, data is loaded and processed + during initialization. If False, data loading is deferred until explicitly called. + + Attributes + ---------- + dataset_type: str + A string indicating the type of dataset ("train" or "test"). + data_params: Recursive Namespace object + A Recursive Namespace object containing training or test data parameters. + dataset: dict + A dictionary containing the loaded dataset, including positions and stars/noisy_stars. + simPSF: object + An instance of the SimPSFToolkit class for simulating PSF. + n_bins_lambda: int + The number of bins in wavelength. + sed_data: tf.Tensor + A TensorFlow tensor containing the SED data for training/testing. + load_data_on_init: bool, optional + A flag used to control data loading steps. If True, data is loaded and processed + during initialization. If False, data loading is deferred until explicitly called. + """ + + def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool=True): + """ + Initialize the dataset handler for PSF simulation. +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) Parameters ---------- dataset_type : str +<<<<<<< HEAD One of {"train", "test", "inference"} indicating dataset usage. data_params : RecursiveNamespace Configuration object with paths, preprocessing options, and metadata. @@ -133,10 +181,36 @@ def __init__( self.load_dataset() self.process_sed_data(self.dataset["SEDs"]) self.validate_and_process_dataset() +======= + A string indicating the type of data ("train" or "test"). + data_params : Recursive Namespace object + A Recursive Namespace object containing parameters for both 'train' and 'test' datasets. 
+ simPSF : PSFSimulator + An instance of the PSFSimulator class for simulating a PSF. + n_bins_lambda : int + The number of bins in wavelength. + load_data : bool, optional + A flag to control whether data should be loaded and processed during initialization. + If True, data is loaded and processed during initialization; if False, data loading + is deferred until explicitly called. + """ + self.dataset_type = dataset_type + self.data_params = data_params.__dict__[dataset_type] + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + if self.load_data_on_init: + self.load_dataset() + self.process_sed_data() +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) else: self.dataset = None self.sed_data = None +<<<<<<< HEAD +======= + +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) def load_dataset(self): """Load dataset. @@ -147,6 +221,7 @@ def load_dataset(self): os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] +<<<<<<< HEAD def validate_and_process_dataset(self): """Validate the dataset structure and convert fields to TensorFlow tensors.""" @@ -184,10 +259,24 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": +======= + self.dataset["positions"] = tf.convert_to_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if "train" == self.dataset_type: + if "noisy_stars" in self.dataset: + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif "test" == self.dataset_type: +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( self.dataset["stars"], dtype=tf.float32 ) +<<<<<<< HEAD def process_sed_data(self, sed_data): @@ -223,11 +312,28 @@ def process_sed_data(self, sed_data): if sed_data is None: raise ValueError("SED data must be provided explicitly or via dataset.") +======= + else: + logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") + elif "inference" == self.dataset_type: + pass + + def process_sed_data(self): + """Process SED Data. + + A method to generate and process SED data. + + """ +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) +<<<<<<< HEAD for _sed in sed_data +======= + for _sed in self.dataset["SEDs"] +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) @@ -285,7 +391,11 @@ def get_obs_positions(data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. 
+<<<<<<< HEAD +======= + +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the star training and test datasets such as star stamps or masks, based on the provided keys. @@ -315,6 +425,7 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ +<<<<<<< HEAD key for key, dataset in [ (train_key, data.training_data.dataset), @@ -323,6 +434,12 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: if key not in dataset ] +======= + key for key, dataset in [(train_key, data.training_data.dataset), (test_key, data.test_data.dataset)] + if key not in dataset + ] + +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) if missing_keys: raise KeyError(f"Missing keys in dataset: {missing_keys}") @@ -338,3 +455,7 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) +<<<<<<< HEAD +======= + +>>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) From 869900fafa9b0240fd42cf858f825ce0dbe90d75 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 053/146] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 115 ++++++++++++++++---------------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index ef267580..805cd5df 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -97,49 +97,74 @@ def __init__( This class manages loading and processing of training and testing data for use during PSF model training and validation. It provides methods to access and preprocess the data. + """ + DataHandler for WaveDiff PSF modeling. + + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, and inference in the WaveDiff framework. Parameters ---------- - dataset_type: str - A string indicating type of data ("train" or "test"). - data_params: Recursive Namespace object - Recursive Namespace object containing training data parameters - simPSF: PSFSimulator - An instance of the PSFSimulator class for simulating a PSF. - n_bins_lambda: int - The number of bins in wavelength. - load_data: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). + simPSF : PSFSimulator + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. + n_bins_lambda : int + Number of wavelength bins used to discretize SEDs. + load_data : bool, optional + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. 
+ dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. Attributes ---------- - dataset_type: str - A string indicating the type of dataset ("train" or "test"). - data_params: Recursive Namespace object - A Recursive Namespace object containing training or test data parameters. - dataset: dict - A dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF: object - An instance of the SimPSFToolkit class for simulating PSF. - n_bins_lambda: int - The number of bins in wavelength. - sed_data: tf.Tensor - A TensorFlow tensor containing the SED data for training/testing. - load_data_on_init: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. """ - def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool=True): + def __init__( + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, + ): """ - Initialize the dataset handler for PSF simulation. ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. Parameters ---------- dataset_type : str -<<<<<<< HEAD One of {"train", "test", "inference"} indicating dataset usage. data_params : RecursiveNamespace Configuration object with paths, preprocessing options, and metadata. @@ -152,7 +177,7 @@ def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: dataset : dict or list, optional A pre-loaded dataset to use directly (overrides `load_data`). sed_data : array-like, optional - Pre-loaded SED data to use directly. If not provided but `dataset` is, + Pre-loaded SED data to use directly. If not provided but `dataset` is, SEDs are taken from `dataset["SEDs"]`. 
Raises @@ -162,7 +187,7 @@ def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: Notes ----- - - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ @@ -181,28 +206,6 @@ def __init__(self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: self.load_dataset() self.process_sed_data(self.dataset["SEDs"]) self.validate_and_process_dataset() -======= - A string indicating the type of data ("train" or "test"). - data_params : Recursive Namespace object - A Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - An instance of the PSFSimulator class for simulating a PSF. - n_bins_lambda : int - The number of bins in wavelength. - load_data : bool, optional - A flag to control whether data should be loaded and processed during initialization. - If True, data is loaded and processed during initialization; if False, data loading - is deferred until explicitly called. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) else: self.dataset = None self.sed_data = None @@ -263,13 +266,11 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"] = tf.convert_to_tensor( self.dataset["positions"], dtype=tf.float32 ) + if "train" == self.dataset_type: - if "noisy_stars" in self.dataset: self.dataset["noisy_stars"] = tf.convert_to_tensor( self.dataset["noisy_stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") elif "test" == self.dataset_type: >>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) if "stars" in self.dataset: From 149a02c80bf728c1730813b60245cd1d9ed58707 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 054/146] Update unit tests associated to changes in data_handler.py --- src/wf_psf/data/data_handler.py | 150 +----------------- .../tests/test_data/data_handler_test.py | 15 ++ 2 files changed, 17 insertions(+), 148 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 805cd5df..94e074ad 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,17 +17,13 @@ import wf_psf.utils.utils as utils import tensorflow as tf from fractions import Fraction -<<<<<<< HEAD from typing import Optional, Union -======= ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) import logging logger = logging.getLogger(__name__) class DataHandler: -<<<<<<< HEAD """ DataHandler for WaveDiff PSF modeling. @@ -92,76 +88,7 @@ def __init__( and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded from disk using `data_params`, and SEDs are extracted and processed automatically. 
-======= - """Data Handler. - - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - """ - DataHandler for WaveDiff PSF modeling. - - This class manages loading, preprocessing, and TensorFlow conversion of datasets used - for PSF model training, testing, and inference in the WaveDiff framework. - - Parameters - ---------- - dataset_type : str - Indicates the dataset mode ("train", "test", or "inference"). - data_params : RecursiveNamespace - Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). - simPSF : PSFSimulator - An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. - n_bins_lambda : int - Number of wavelength bins used to discretize SEDs. - load_data : bool, optional - If True (default), loads and processes data during initialization. If False, data loading - must be triggered explicitly. - dataset : dict or list, optional - If provided, uses this pre-loaded dataset instead of triggering automatic loading. - sed_data : dict or list, optional - If provided, uses this SED data directly instead of extracting it from the dataset. - - Attributes - ---------- - dataset_type : str - Indicates the dataset mode ("train", "test", or "inference"). - data_params : RecursiveNamespace - Configuration parameters for data access and structure. - simPSF : PSFSimulator - Simulator used to transform SEDs into TensorFlow-ready tensors. - n_bins_lambda : int - Number of wavelength bins in the SED representation. - load_data_on_init : bool - Whether data was loaded automatically during initialization. - dataset : dict - Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. - sed_data : tf.Tensor - TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. - """ - - def __init__( - self, - dataset_type, - data_params, - simPSF, - n_bins_lambda, - load_data: bool = True, - dataset: Optional[Union[dict, list]] = None, - sed_data: Optional[Union[dict, list]] = None, - ): - """ - Initialize the DataHandler for PSF dataset preparation. - - This constructor sets up the dataset handler used for PSF simulation tasks, - such as training, testing, or inference. It supports three modes of use: - - 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing - must be triggered manually via `load_dataset()` and `process_sed_data()`. - 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, - and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. - 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded - from disk using `data_params`, and SEDs are extracted and processed automatically. - + Parameters ---------- dataset_type : str @@ -210,10 +137,6 @@ def __init__( self.dataset = None self.sed_data = None -<<<<<<< HEAD -======= - ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) def load_dataset(self): """Load dataset. 
@@ -224,7 +147,6 @@ def load_dataset(self): os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] -<<<<<<< HEAD def validate_and_process_dataset(self): """Validate the dataset structure and convert fields to TensorFlow tensors.""" @@ -262,79 +184,26 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": -======= - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - - if "train" == self.dataset_type: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - elif "test" == self.dataset_type: ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( self.dataset["stars"], dtype=tf.float32 ) -<<<<<<< HEAD - - - def process_sed_data(self, sed_data): - """ - Generate and process SED (Spectral Energy Distribution) data. - - This method transforms raw SED inputs into TensorFlow tensors suitable for model input. - It generates wavelength-binned SED elements using the PSF simulator, converts the result - into a tensor, and transposes it to match the expected shape for training or inference. - - Parameters - ---------- - sed_data : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. - - Raises - ------ - ValueError - If `sed_data` is None. - - Notes - ----- - The resulting tensor is stored in `self.sed_data` and has shape - `(num_samples, n_bins_lambda, n_components)`, where: - - `num_samples` is the number of SEDs, - - `n_bins_lambda` is the number of wavelength bins, - - `n_components` is the number of components per SED (e.g., filters or basis terms). - - The intermediate tensor is created with `tf.float64` for precision during generation, - but is converted to `tf.float32` after processing for use in training. - """ - if sed_data is None: - raise ValueError("SED data must be provided explicitly or via dataset.") - -======= else: logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") elif "inference" == self.dataset_type: pass - def process_sed_data(self): + def process_sed_data(self, sed_data): """Process SED Data. A method to generate and process SED data. """ ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) -<<<<<<< HEAD for _sed in sed_data -======= - for _sed in self.dataset["SEDs"] ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) @@ -392,11 +261,7 @@ def get_obs_positions(data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. 
-<<<<<<< HEAD - -======= ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the star training and test datasets such as star stamps or masks, based on the provided keys. @@ -426,7 +291,6 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ -<<<<<<< HEAD key for key, dataset in [ (train_key, data.training_data.dataset), @@ -435,12 +299,6 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: if key not in dataset ] -======= - key for key, dataset in [(train_key, data.training_data.dataset), (test_key, data.test_data.dataset)] - if key not in dataset - ] - ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) if missing_keys: raise KeyError(f"Missing keys in dataset: {missing_keys}") @@ -456,7 +314,3 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) -<<<<<<< HEAD -======= - ->>>>>>> 80aad95 (Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 91226d66..02214959 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -136,6 +136,21 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.load_dataset() data_handler.validate_and_process_dataset() + data_handler = DataHandler( + dataset_type="train", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=10, + load_data=False + ) + + data_handler.load_dataset() + data_handler.process_sed_data(mock_dataset["SEDs"]) + + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + data_handler._validate_dataset_structure() + mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + def test_get_obs_positions(mock_data): observed_positions = get_obs_positions(mock_data) From f0c7abefef4f5d5ce7e423286430969c9d9ecf7a Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 055/146] automatic formatting --- src/wf_psf/data/data_handler.py | 37 ++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 94e074ad..7b5a6705 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -104,7 +104,7 @@ def __init__( dataset : dict or list, optional A pre-loaded dataset to use directly (overrides `load_data`). sed_data : array-like, optional - Pre-loaded SED data to use directly. If not provided but `dataset` is, + Pre-loaded SED data to use directly. If not provided but `dataset` is, SEDs are taken from `dataset["SEDs"]`. Raises @@ -114,7 +114,7 @@ def __init__( Notes ----- - - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. 
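
        Examples
        --------
        A pre-loaded dataset sketch (assuming `my_dataset` is a dict holding
        `positions` and `SEDs` arrays)::

            handler = DataHandler("test", data_params, simPSF, n_bins_lambda, dataset=my_dataset)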
""" @@ -194,11 +194,38 @@ def _convert_dataset_to_tensorflow(self): pass def process_sed_data(self, sed_data): - """Process SED Data. + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. - A method to generate and process SED data. + Raises + ------ + ValueError + If `sed_data` is None. + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. """ + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 @@ -261,7 +288,7 @@ def get_obs_positions(data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. - + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the star training and test datasets such as star stamps or masks, based on the provided keys. From 44460e4491b070c6a844866b5d0e1b75d614fc54 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:16:15 +0200 Subject: [PATCH 056/146] Refactor: add ZernikeInputs dataclass, ZernikeInputsFactory, helper methods for assembling zernike contributions according to run_type mode: training, simulation, or inference --- src/wf_psf/data/data_zernike_utils.py | 207 +++++++++++++++++++------- 1 file changed, 152 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 760adb11..648b260c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -11,14 +11,89 @@ """ +from dataclasses import dataclass +from typing import Optional, Union import numpy as np import tensorflow as tf -from typing import Optional +from wf_psf.utils.read_config import RecursiveNamespace import logging logger = logging.getLogger(__name__) +@dataclass +class ZernikeInputs: + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + batch_size: int + + +class ZernikeInputsFactory: + @staticmethod + def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + """Builds a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. 
+ run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset = None + positions = None + + if run_type in {"training", "simulation"}: + centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets + positions = np.concatenate( + [ + data.training_dataset["positions"], + data.test_dataset["positions"] + ], + axis=0, + ) + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Zernike prior explicitly provided; ignoring dataset-based prior despite use_prior=True." + ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + centroid_dataset = None + positions = data["positions"] + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + batch_size=model_params.batch_size, + ) + + def get_np_zernike_prior(data): """Get the zernike prior from the provided dataset. @@ -45,80 +120,102 @@ def get_np_zernike_prior(data): return zernike_prior - -def get_zernike_prior(model_params, data, batch_size: int=16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + combined = np.zeros((n_samples, max_order), dtype=np.float32) + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """ + Assemble the total Zernike contribution map by combining the prior, + centroid correction, and CCD misalignment correction. Parameters ---------- model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior (e.g., from PDC or another model). 
+ centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. Returns ------- tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - + A tensor representing the full Zernike contribution map. """ - # List of zernike contribution + zernike_contribution_list = [] - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, np.ndarray): + zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) + tf.convert_to_tensor(centroid_correction, dtype=tf.float32) ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) + logger.info("Skipping centroid correction (not enabled or no dataset).") - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append( + tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) ) + else: + logger.info("Skipping CCD misalignment correction (not enabled or no positions).") - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. 
Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) - zernike_contribution += current_zernike_contribution + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): From 23a86cc2c8a37b5f7a3decffd9c5cb566c1ad6b8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:18:57 +0200 Subject: [PATCH 057/146] Update docstring describing data_conf types permitted --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 1d2e267f..797be8fc 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -25,8 +25,8 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): training_conf : RecursiveNamespace Configuration object containing model parameters and training hyperparameters. Supports attribute-style access to nested fields. - data_conf : RecursiveNamespace - Configuration object containing data-related parameters. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). weights_path_pattern : str Glob-style pattern used to locate the model weights file. From b2fc9281c455bfa2f9ed412d7beb654583577eed Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:45:41 +0200 Subject: [PATCH 058/146] Move imports to method to avoid circular imports --- src/wf_psf/data/centroids.py | 2 +- src/wf_psf/instrument/ccd_misalignments.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index c95b303f..a5992608 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -10,7 +10,6 @@ import scipy.signal as scisig from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff import tensorflow as tf from typing import Optional @@ -127,6 +126,7 @@ def compute_zernike_tip_tilt( - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. """ + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 1b153cb3..d2bd2fa2 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, data): @@ -383,6 +383,7 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. 
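
        Notes
        -----
        `defocus_to_zk4_wavediff` is imported inside the method body below,
        rather than at module level, so that this module and
        `wf_psf.data.data_zernike_utils` do not import each other at load
        time (the circular-import issue this patch addresses).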
""" + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff dz = self.get_dz_from_position(pos) From 34da05df00b49cae8cf2e392f1f37d9c8ab8bf99 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:47:35 +0200 Subject: [PATCH 059/146] Remove batch_size arg from ZernikeInputsFactory ; raise ValueError to check Zernike contributions have the same number of samples --- src/wf_psf/data/data_zernike_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 648b260c..b03ff400 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -15,6 +15,8 @@ from typing import Optional, Union import numpy as np import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -26,7 +28,6 @@ class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections - batch_size: int class ZernikeInputsFactory: @@ -89,8 +90,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions, - batch_size=model_params.batch_size, + misalignment_positions=positions ) @@ -133,6 +133,8 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order), dtype=np.float32) for contrib in contributions: From c683aa29fb624a6bf27aec3ae1a381f7da203981 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:49:32 +0200 Subject: [PATCH 060/146] Add and set run_type attribute to DataConfigHandler object in TrainingConfigHandler constructor --- src/wf_psf/utils/configs_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index a037069f..840f3822 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -188,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) From df9a4fe07ca574680f921b03e0f0d59366114dea Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:51:04 +0200 Subject: [PATCH 061/146] Add and set run_type attribute ; Replace var name end with end_sample in PSFInferenceEngine --- src/wf_psf/inference/psf_inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6f798245..5be49343 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -162,6 +162,7 @@ def data_handler(self): load_data=False, dataset=None, ) + 
self._data_handler.run_type = "inference" return self._data_handler @property @@ -272,7 +273,7 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: counter = 0 while counter < n_samples: # Calculate the batch end element - end = min(counter + self.batch_size, n_samples) + end_sample = min(counter + self.batch_size, n_samples) # Define the batch positions batch_pos = positions[counter:end_sample, :] @@ -281,10 +282,10 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) - self.inferred_psfs[counter:end, :, :] = batch_psfs.numpy() + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter - counter = end + counter = end_sample return self._inferred_psfs From 5d12b55c3e2e77e951a6d9481f1ec11269586c5d Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:52:26 +0200 Subject: [PATCH 062/146] Refactor TFPhysicalPolychromaticField to lazy load property objects and attributes dynamically at run-time according to the run_type: training or inference --- .../psf_model_physical_polychromatic.py | 339 ++++++++---------- 1 file changed, 148 insertions(+), 191 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index baec8923..971dbb63 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -11,7 +11,7 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_obs_positions -from wf_psf.data.data_zernike_utils import get_zernike_prior +from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, @@ -98,8 +98,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): A Recursive Namespace object containing parameters for this PSF model class. training_params: Recursive Namespace A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler - A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients. + data: DataConfigHandler or dict + A DataConfigHandler object or dict that provides access to single or multiple datasets (e.g. train and test), as well as prior knowledge like Zernike coefficients. coeff_mat: Tensor or None, optional Coefficient matrix defining the parametric PSF field model. @@ -109,204 +109,151 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): Initialized instance of the TFPhysicalPolychromaticField class. """ super().__init__(model_params, training_params, coeff_mat) - self._initialize_parameters_and_layers( - model_params, training_params, data, coeff_mat - ) + self.model_params = model_params + self.training_params = training_params + self.data = data + self.run_type = data.run_type - def _initialize_parameters_and_layers( - self, - model_params: RecursiveNamespace, - training_params: RecursiveNamespace, - data: DataConfigHandler, - coeff_mat: Optional[tf.Tensor] = None, - ): - """Initialize Parameters of the PSF model. 
- - This method sets up the PSF model parameters, observational positions, - Zernike coefficients, and components required for the automatically - differentiable optical forward model. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - Notes - ----- - - Initializes Zernike parameters based on dataset priors. - - Configures the PSF model layers according to `model_params`. - - If `coeff_mat` is provided, the model coefficients are updated accordingly. - """ + # Initialize the model parameters and layers self.output_Q = model_params.output_Q - self.obs_pos = get_obs_positions(data) self.l2_param = model_params.param_hparams.l2_param - # Inputs: Save optimiser history Parametric model features - self.save_optim_history_param = ( - model_params.param_hparams.save_optim_history_param - ) - # Inputs: Save optimiser history NonParameteric model features - self.save_optim_history_nonparam = ( - model_params.nonparam_hparams.save_optim_history_nonparam - ) - self._initialize_zernike_parameters(model_params, data) - self._initialize_layers(model_params, training_params) + self.output_dim = model_params.output_dim + + # Initialise lazy loading of external Zernike prior + self._external_prior = None # Initialize the model parameters with non-default value if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) - def _initialize_zernike_parameters(self, model_params, data): - """Initialize the Zernike parameters. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - """ - self.zks_prior = get_zernike_prior(model_params, data, data.batch_size) - self.n_zks_total = max( - model_params.param_hparams.n_zernikes, - tf.cast(tf.shape(self.zks_prior)[1], tf.int32), + def _assemble_zernike_contributions(self): + zks_inputs = ZernikeInputsFactory.build( + data=self.data, + run_type=self.run_type, + model_params=self.model_params, + prior=self._external_prior if hasattr(self, "_external_prior") else None, ) - self.zernike_maps = psfm.generate_zernike_maps_3d( - self.n_zks_total, model_params.pupil_diameter + return assemble_zernike_contributions( + model_params=self.model_params, + zernike_prior=zks_inputs.zernike_prior, + centroid_dataset=zks_inputs.centroid_dataset, + positions=zks_inputs.misalignment_positions, + batch_size=self.training_params.batch_size, ) - def _initialize_layers(self, model_params, training_params): - """Initialize the layers of the PSF model. - - This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. 
-        training_params: Recursive Namespace
-            A Recursive Namespace object containing training hyperparameters for this PSF model class.
-        coeff_mat: Tensor or None, optional
-            Initialization of the coefficient matrix defining the parametric psf field model.
-
-        """
-        self._initialize_physical_layer(model_params)
-        self._initialize_polynomial_Z_field(model_params)
-        self._initialize_Zernike_OPD(model_params)
-        self._initialize_batch_polychromatic_layer(model_params, training_params)
-        self._initialize_nonparametric_opd_layer(model_params, training_params)
-
-    def _initialize_physical_layer(self, model_params):
-        """Initialize the physical layer of the PSF model.
-
-        This method initializes the physical layer of the PSF model using parameters
-        specified in the `model_params` object.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
-        """
-        self.tf_physical_layer = TFPhysicalLayer(
-            self.obs_pos,
-            self.zks_prior,
-            interpolation_type=model_params.interpolation_type,
-            interpolation_args=model_params.interpolation_args,
+    @property
+    def save_param_history(self) -> bool:
+        """Check if the model should save the optimization history for parametric features."""
+        return getattr(self.model_params.param_hparams, "save_optim_history_param", False)
+
+    @property
+    def save_nonparam_history(self) -> bool:
+        """Check if the model should save the optimization history for non-parametric features."""
+        return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False)
+
+    # === Lazy properties ===
+    @property
+    def obs_pos(self):
+        """Lazy loading of the observation positions."""
+        if not hasattr(self, "_obs_pos"):
+            if self.run_type in ("training", "simulation"):
+                # Get the observation positions from the data handler
+                self._obs_pos = get_obs_positions(self.data)
+            elif self.run_type == "inference":
+                # For inference there is no train/test split; read the
+                # positions directly from the provided dataset
+                self._obs_pos = self.data.dataset["positions"]
        return self._obs_pos
+
+    @property
+    def zks_total_contribution(self):
+        """Lazily load all Zernike contributions, including prior and corrections."""
+        if not hasattr(self, "_zks_total_contribution"):
+            self._zks_total_contribution = self._assemble_zernike_contributions()
+        return self._zks_total_contribution
+
+    @property
+    def n_zks_total(self):
+        """Get the total number of Zernike coefficients."""
+        if not hasattr(self, "_n_zks_total"):
+            self._n_zks_total = max(
+                self.model_params.param_hparams.n_zernikes,
+                tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32),
            )

-    def _initialize_polynomial_Z_field(self, model_params):
-        """Initialize the polynomial Zernike field of the PSF model.
-
-        This method initializes the polynomial Zernike field of the PSF model using
-        parameters specified in the `model_params` object.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
- - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + return self._n_zks_total + + @property + def zernike_maps(self): + """Lazy loading of the Zernike maps.""" + if not hasattr(self, "_zernike_maps"): + self._zernike_maps = psfm.generate_zernike_maps_3d( + self.n_zks_total, self.model_params.pupil_diameter ) + return self._zernike_maps + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) - - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. - - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. 
-
-
-        """
-        self.batch_size = training_params.batch_size
-        self.obscurations = psfm.tf_obscurations(
-            pupil_diam=model_params.pupil_diameter,
-            N_filter=model_params.LP_filter_length,
-            rotation_angle=model_params.obscuration_rotation_angle,
-        )
-        self.output_dim = model_params.output_dim
+    @tf_poly_Z_field.deleter
+    def tf_poly_Z_field(self):
+        del self._tf_poly_Z_field

-        self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
-            obscurations=self.obscurations,
-            output_Q=self.output_Q,
-            output_dim=self.output_dim,
+    @property
+    def tf_physical_layer(self):
+        """Lazy loading of the physical layer of the PSF model."""
+        if not hasattr(self, "_tf_physical_layer"):
+            self._tf_physical_layer = TFPhysicalLayer(
+                self.obs_pos,
+                self.zks_total_contribution,
+                interpolation_type=self.model_params.interpolation_type,
+                interpolation_args=self.model_params.interpolation_args,
            )
+        return self._tf_physical_layer
+
+    @property
+    def tf_zernike_OPD(self):
+        """Lazy loading of the Zernike Optical Path Difference (OPD) layer."""
+        if not hasattr(self, "_tf_zernike_OPD"):
+            self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps)
+        return self._tf_zernike_OPD
+
+    @property
+    def tf_batch_poly_PSF(self):
+        """Lazily initialize the batch polychromatic PSF layer."""
+        if not hasattr(self, "_tf_batch_poly_PSF"):
+            obscurations = psfm.tf_obscurations(
+                pupil_diam=self.model_params.pupil_diameter,
+                N_filter=self.model_params.LP_filter_length,
+                rotation_angle=self.model_params.obscuration_rotation_angle,
            )

-    def _initialize_nonparametric_opd_layer(self, model_params, training_params):
-        """Initialize the non-parametric OPD layer.
-
-        This method initializes the non-parametric OPD layer using the provided
-        `model_params` and `training_params`.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
-        training_params: Recursive Namespace
-            A Recursive Namespace object containing training hyperparameters for this PSF model class.
-
-        """
-        # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam
-        # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy()
-
-        self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD(
-            x_lims=model_params.x_lims,
-            y_lims=model_params.y_lims,
-            random_seed=model_params.param_hparams.random_seed,
-            d_max=model_params.nonparam_hparams.d_max_nonparam,
-            opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
-        )
+            self._tf_batch_poly_PSF = TFBatchPolychromaticPSF(
+                obscurations=obscurations,
+                output_Q=self.output_Q,
+                output_dim=self.output_dim,
+            )
+        return self._tf_batch_poly_PSF
+
+    @property
+    def tf_np_poly_opd(self):
+        """Lazy loading of the non-parametric polynomial variations OPD layer."""
+        if not hasattr(self, "_tf_np_poly_opd"):
+            self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD(
+                x_lims=self.model_params.x_lims,
+                y_lims=self.model_params.y_lims,
+                random_seed=self.model_params.param_hparams.random_seed,
+                d_max=self.model_params.nonparam_hparams.d_max_nonparam,
+                opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
+            )
+        return self._tf_np_poly_opd

    def get_coeff_matrix(self):
        """Get coefficient matrix."""
@@ -331,23 +278,21 @@ def get_coeff_matrix(self):
        """
        self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat)

+
    def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None:
        """Set the output sampling rate (output_Q) for PSF generation. 
This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. @@ -359,6 +304,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -472,12 +418,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -520,10 +470,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -548,10 +501,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -589,6 +545,7 @@ def compute_zernikes(self, input_positions): padded_zernike_params, padded_zernike_prior = self.pad_zernikes( zernike_params, zernike_prior ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -685,7 +642,7 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) From 8d6a726fb52a91a85c8a43382f388cf6d85ab158 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:53:46 +0200 Subject: [PATCH 063/146] Update/Add unit tests to test refactoring 
changes to psf_model_physical_polychromatic.py and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 296 ++++++++++++++++-- .../psf_model_physical_polychromatic_test.py | 202 ++---------- 2 files changed, 310 insertions(+), 188 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 692624be..90994dde 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,40 +1,296 @@ import pytest import numpy as np +from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - get_zernike_prior, + ZernikeInputs, + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, compute_zernike_tip_tilt, ) from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from types import SimpleNamespace as RecursiveNamespace -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + +@pytest.fixture +def dummy_positions(): + return np.random.rand(4, 2).astype(np.float32) + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + +def test_training_without_prior(mock_model_params): + mock_model_params.use_prior = False + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((3, 2))} + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is data + assert zinputs.zernike_prior is None + np.testing.assert_array_equal( + zinputs.misalignment_positions, + np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) + ) + +@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") +def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((2, 2))} + mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] + mock_get_prior.assert_called_once_with(data) + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + + assert "Zernike prior explicitly provided" in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + 
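+    """Inference mode with dict data: the prior is taken from the dict's
+    `zernike_prior` key and `centroid_dataset` stays unset."""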
mock_model_params.use_prior = True + data = { + "positions": np.ones((5, 2)), + "zernike_prior": np.array([42.0, 0.0]) + } + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + assert zinputs.centroid_dataset is None + assert (zinputs.zernike_prior == data["zernike_prior"]).all() + np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) ) - assert np.array_equal(zernike_priors, expected_zernike_priors) + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + result = get_np_zernike_prior(data) -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array([ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ]) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) 
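+    # Padding an empty (3, 0) batch should append all-zero columns, giving (3, 4).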
+ assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([ + [1 + 5, 2 + 0], + [3 + 6, 4 + 0] + ]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, 
+ param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), dtype=tf.float32) + + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=tensor_prior + ) + + assert isinstance(result, tf.Tensor) + np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): + mock_model_params.add_ccd_misalignments = False + mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 + + with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=None + ) def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 7e967465..13d78667 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,6 +9,7 @@ import pytest import numpy as np import tensorflow as tf +from unittest.mock import PropertyMock from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -31,6 +32,7 @@ def zks_prior(): def mock_data(mocker): mock_instance = mocker.Mock(spec=DataConfigHandler) # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" mock_instance.training_data = mocker.Mock() mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} mock_instance.test_data = mocker.Mock() @@ -46,145 +48,11 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock - -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.data_handler.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data 
- ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor - ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): # Create training params mock object mock_training_params = mocker.Mock() - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mocker.patch( - 
"wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - # Create TFPhysicalPolychromaticField instance psf_field_instance = TFPhysicalPolychromaticField( mock_model_params, mock_training_params, mock_data @@ -202,8 +70,8 @@ def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + physical_layer_instance._n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) # Call the method under test padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( @@ -224,7 +92,7 @@ def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) @@ -247,7 +115,7 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) @@ -262,43 +130,41 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity + # Expected output of mock components + padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) + # Patch tf_poly_Z_field property + mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + mocker.patch.object( + TFPhysicalPolychromaticField, + 'tf_poly_Z_field', + new_callable=PropertyMock, + return_value=mock_tf_poly_Z_field + ) + # Patch tf_physical_layer property + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, + 'tf_physical_layer', + new_callable=PropertyMock, + return_value=mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) + + # Patch pad_zernikes instance method directly (this one isn't a property) mocker.patch.object( physical_layer_instance, - "pad_zernikes", - return_value=(padded_zernike_param, padded_zernike_prior), + 'pad_zernikes', + return_value=(padded_zernike_param, padded_zernike_prior) ) - # Call the method under test - 
zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal - assert zernike_coeffs.shape == expected_values.shape + # Run the test + zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Assert that the tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) + assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file From 8eb4454788e59c314b02c4ca661af1234fdd129f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 22 Jun 2025 00:08:51 +0200 Subject: [PATCH 064/146] Replace arg: data in compute_ccd_misalignment with positions --- src/wf_psf/instrument/ccd_misalignments.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index d2bd2fa2..b2d06a20 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -13,22 +13,23 @@ from wf_psf.data.data_handler import get_np_obs_positions -def compute_ccd_misalignment(model_params, data): +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: """Compute CCD misalignment. Parameters ---------- model_params : RecursiveNamespace Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. Returns ------- zernike_ccd_misalignment_array : np.ndarray Numpy array containing the Zernike contributions to model the CCD chip misalignments. 
""" - obs_positions = get_np_obs_positions(data) + obs_positions = positions ccd_misalignment_calculator = CCDMisalignmentCalculator( tiles_path=model_params.ccd_misalignments_input_path, From 5fa9aafd9f6549ead84d280056e365405053fa8d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:15:06 +0200 Subject: [PATCH 065/146] Correct object attributes for DataConfigHandler in ZernikeInputsFactory --- src/wf_psf/data/data_zernike_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index b03ff400..e50381c3 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_dataset["positions"], - data.test_dataset["positions"] + data.training_data.dataset["positions"], + data.test_data.dataset["positions"] ], axis=0, ) From e0361b7685612b4c749edc10e94349c8556829d6 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:29:06 +0200 Subject: [PATCH 066/146] Add missing return for tf_physical_layer property --- src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 971dbb63..7f19171e 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -218,6 +218,7 @@ def tf_physical_layer(self): interpolation_type=self.model_params.interpolation_type, interpolation_args=self.model_params.interpolation_args, ) + return self._tf_physical_layer @property def tf_zernike_OPD(self): From 9ec8bc00b7b5baf989fcc791c1d711071eca37f2 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:02:17 +0200 Subject: [PATCH 067/146] Add tf_utils.py module to tf_modules subpackage --- src/wf_psf/psf_models/tf_modules/tf_utils.py | 45 ++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 src/wf_psf/psf_models/tf_modules/tf_utils.py diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..a4795f89 --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,45 @@ +"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf +import numpy as np + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). 
+ + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) + From fdc8c6cb71483c780dbf30cf1491bed23753e97c Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:03:50 +0200 Subject: [PATCH 068/146] Use ensure_tensor method from tf_utils.py to check/convert to tensorflow type; Remove get_obs_positions and replace with ensure_tensor method; add property tf_positions --- src/wf_psf/data/data_handler.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 7b5a6705..edba744a 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -15,6 +15,7 @@ import os import numpy as np import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf from fractions import Fraction from typing import Optional, Union @@ -137,6 +138,10 @@ def __init__( self.dataset = None self.sed_data = None + @property + def tf_positions(self): + return ensure_tensor(self.dataset["positions"]) + def load_dataset(self): """Load dataset. @@ -232,7 +237,8 @@ def process_sed_data(self, sed_data): ) for _sed in sed_data ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) @@ -268,24 +274,6 @@ def get_np_obs_positions(data): return obs_positions -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. 
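For readers following the refactor, a minimal usage sketch of the `ensure_tensor` helper introduced above; the sample arrays are illustrative only and are not taken from the repository:

    import numpy as np
    import tensorflow as tf

    from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor

    # A NumPy array is converted to a tensor of the requested dtype.
    positions = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    tf_positions = ensure_tensor(positions)
    assert tf_positions.dtype == tf.float32

    # A tensor that already holds the requested dtype is returned unchanged,
    # while a tensor with a different dtype is cast rather than copied.
    assert ensure_tensor(tf_positions) is tf_positions
    assert ensure_tensor(tf.constant([1, 2]), dtype=tf.float32).dtype == tf.float32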
From d37b36d187ee4c9bf8e7c060530387bb62491268 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:08:05 +0200 Subject: [PATCH 069/146] Refactor: Add eager-mode helpers and avoid lazy-loading obscurations in graph mode --- .../psf_model_physical_polychromatic.py | 76 +++++++++++-------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 7f19171e..0c5d3d06 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,7 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -21,6 +21,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.configs_handler import DataConfigHandler import logging @@ -112,7 +113,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.model_params = model_params self.training_params = training_params self.data = data - self.run_type = data.run_type + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() # Initialize the model parameters and layers self.output_Q = model_params.output_Q @@ -126,6 +128,22 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Eagerly initialise tf_batch_poly_PSF + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + + + def _get_run_type(self, data): + if hasattr(data, 'run_type'): + run_type = data.run_type + elif isinstance(data, dict) and 'run_type' in data: + run_type = data['run_type'] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + def _assemble_zernike_contributions(self): zks_inputs = ZernikeInputsFactory.build( data=self.data, @@ -150,21 +168,20 @@ def save_param_history(self) -> bool: def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) - - # === Lazy properties ===. 
- @property - def obs_pos(self): - """Lazy loading of the observation positions.""" - if not hasattr(self, "_obs_pos"): - if self.run_type == "training" or self.run_type == "simulation": - # Get the observation positions from the data handler - self._obs_pos = get_obs_positions(self.data) - elif self.run_type == "inference": - # For inference, we might not have a data handler, so we use the model parameters - self._obs_pos = self.data.dataset["positions"] - return self._obs_pos + def get_obs_pos(self): + assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: {self.run_type}" + + if self.run_type in {"training", "simulation"}: + raw_pos = get_np_obs_positions(self.data) + else: + raw_pos = self.data.dataset["positions"] + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. @property def zks_total_contribution(self): """Lazily load all Zernike contributions, including prior and corrections.""" @@ -227,22 +244,20 @@ def tf_zernike_OPD(self): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - @property - def tf_batch_poly_PSF(self): - """Lazily initialize the batch polychromatic PSF layer.""" - if not hasattr(self, "_tf_batch_poly_PSF"): - obscurations = psfm.tf_obscurations( + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + + obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, ) - self._tf_batch_poly_PSF = TFBatchPolychromaticPSF( + return TFBatchPolychromaticPSF( obscurations=obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - return self._tf_batch_poly_PSF @property def tf_np_poly_opd(self): @@ -646,23 +661,24 @@ def call(self, inputs, training=True): if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - - # Propagate to obtain the OPD + + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD + # Add L2 regularization loss on parametric OPD maps self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) + self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) ) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Combine both contributions + opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions From 107621584d4e268623adc5dc02f02e984cb153ea Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:09:01 +0200 Subject: [PATCH 070/146] Replace deprecated get_obs_positions with get_np_obs_positions and apply ensure_tensor to convert obs_pos to tensorflow float32 --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index f39f8bdd..df18bfd3 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,8 +16,9 @@ TFPhysicalLayer, ) from 
wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From b586cf64cd1aaca6196e5d180dbc8f44fa9cb477 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:06:04 +0200 Subject: [PATCH 071/146] Remove tf.convert_to_tensor from all Zernike list contributors --- src/wf_psf/data/data_zernike_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index e50381c3..309732d8 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -53,7 +53,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = None positions = None - if run_type in {"training", "simulation"}: + if run_type in {"training", "simulation", "metrics"}: centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ @@ -178,9 +178,9 @@ def assemble_zernike_contributions( # Prior if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") - if isinstance(zernike_prior, np.ndarray): - zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) - zernike_contribution_list.append(zernike_prior) + if isinstance(zernike_prior, tf.Tensor): + zernike_prior = zernike_prior.numpy() + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -190,9 +190,7 @@ def assemble_zernike_contributions( centroid_correction = compute_centroid_correction( model_params, centroid_dataset, batch_size=batch_size ) - zernike_contribution_list.append( - tf.convert_to_tensor(centroid_correction, dtype=tf.float32) - ) + zernike_contribution_list.append(centroid_correction) else: logger.info("Skipping centroid correction (not enabled or no dataset).") @@ -200,9 +198,7 @@ def assemble_zernike_contributions( if model_params.add_ccd_misalignments and positions is not None: logger.info("Computing CCD misalignment correction...") ccd_misalignment = compute_ccd_misalignment(model_params, positions) - zernike_contribution_list.append( - tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) - ) + zernike_contribution_list.append(ccd_misalignment) else: logger.info("Skipping CCD misalignment correction (not enabled or no positions).") From 78b3db1062ea96d2359353937d67774fc44bc317 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:07:20 +0200 Subject: [PATCH 072/146] Add and set self.data_conf.run_type value to 'metrics' in MetricsConfigHandler --- src/wf_psf/utils/configs_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 840f3822..8de55714 100644 --- a/src/wf_psf/utils/configs_handler.py +++ 
b/src/wf_psf/utils/configs_handler.py @@ -263,6 +263,7 @@ def __init__(self, metrics_conf, file_handler, training_conf=None): self._file_handler = file_handler self.training_conf = training_conf self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) self.trained_psf_model = self._load_trained_psf_model() From 733760d322aa14ad180bd200f67b64e6e6fa1329 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:19:32 +0200 Subject: [PATCH 073/146] Eagerly precompute Zernike components; add support for 'metrics' run_type - Precomputed `zks_total_contribution` outside the TensorFlow graph and converted it to tf.float32. - Calculated `n_zks_total` from the contribution shape and param config. - Eagerly generated Zernike maps and stored as tf.float32. - Derived OPD dimension (`opd_dim`) from Zernike map shape. - Generated obscurations via `tf_obscurations` and stored as tf.complex64. - Added 'metrics' as a valid `run_type` alongside 'training', 'simulation', and 'inference'. - Adjusted `get_obs_pos()` logic to treat 'metrics' like 'training' and 'simulation' for position loading. These changes avoid runtime `.numpy()` calls inside `@tf.function` contexts, improve robustness across run modes, and ensure compatibility with training and evaluation pipelines. --- .../psf_model_physical_polychromatic.py | 68 ++++++++++++------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 0c5d3d06..08c5ce71 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -128,6 +128,32 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) + + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1] + ) + + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, + pupil_diam=self.model_params.pupil_diameter + ) + + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] + + # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) + # Eagerly initialise tf_batch_poly_PSF self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() @@ -140,7 +166,7 @@ def _get_run_type(self, data): else: raise ValueError("data must have a 'run_type' attribute or key") - if run_type not in {"training", "simulation", "inference"}: + if run_type not in {"training", "simulation", "metrics", "inference"}: raise ValueError(f"Unknown run_type: {run_type}") return run_type @@ -170,9 +196,9 @@ def save_nonparam_history(self) -> bool: return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) def get_obs_pos(self): - assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: 
{self.run_type}"
+        assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}"
 
-        if self.run_type in {"training", "simulation"}:
+        if self.run_type in {"training", "simulation", "metrics"}:
             raw_pos = get_np_obs_positions(self.data)
         else:
             raw_pos = self.data.dataset["positions"]
@@ -184,30 +210,26 @@ def get_obs_pos(self):
 
     # === Lazy properties ===.
     @property
     def zks_total_contribution(self):
-        """Lazily load all Zernike contributions, including prior and corrections."""
-        if not hasattr(self, "_zks_total_contribution"):
-            self._zks_total_contribution = self._assemble_zernike_contributions()
         return self._zks_total_contribution
-
+    
     @property
     def n_zks_total(self):
         """Get the total number of Zernike coefficients."""
-        if not hasattr(self, "_n_zks_total"):
-            self._n_zks_total = max(
-                self.model_params.param_hparams.n_zernikes,
-                tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32),
-            )
         return self._n_zks_total
 
     @property
     def zernike_maps(self):
-        """Lazy loading of the Zernike maps."""
-        if not hasattr(self, "_zernike_maps"):
-            self._zernike_maps = psfm.generate_zernike_maps_3d(
-                self.n_zks_total, self.model_params.pupil_diameter
-            )
+        """Get Zernike maps."""
         return self._zernike_maps
-
+    
+    @property
+    def opd_dim(self):
+        return self._opd_dim
+    
+    @property
+    def obscurations(self):
+        return self._obscurations
+    
     @property
     def tf_poly_Z_field(self):
         """Lazy loading of the polynomial Zernike field layer."""
@@ -246,15 +268,9 @@ def tf_zernike_OPD(self):
 
     def _build_tf_batch_poly_PSF(self):
         """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations."""
-
-        obscurations = psfm.tf_obscurations(
-            pupil_diam=self.model_params.pupil_diameter,
-            N_filter=self.model_params.LP_filter_length,
-            rotation_angle=self.model_params.obscuration_rotation_angle,
-        )
 
         return TFBatchPolychromaticPSF(
-            obscurations=obscurations,
+            obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
         )
@@ -267,7 +283,7 @@ def tf_np_poly_opd(self):
             x_lims=self.model_params.x_lims,
             y_lims=self.model_params.y_lims,
             d_max=self.model_params.nonparam_hparams.d_max_nonparam,
-            opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
+            opd_dim=self.opd_dim,
         )
         return self._tf_np_poly_opd

From c515666aa6d28f2df8e6cdda1162fdaa1e6ad806 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 25 Jun 2025 11:13:33 +0200
Subject: [PATCH 074/146] Correct dataset_type value error: replace "training"
 with "train"

---
 src/wf_psf/data/data_handler.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index edba744a..61796201 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -180,17 +180,13 @@ def _validate_dataset_structure(self):
 
     def _convert_dataset_to_tensorflow(self):
         """Convert dataset to TensorFlow tensors."""
-        self.dataset["positions"] = tf.convert_to_tensor(
-            self.dataset["positions"], dtype=tf.float32
-        )
-        if self.dataset_type == "training":
-            self.dataset["noisy_stars"] = tf.convert_to_tensor(
-                self.dataset["noisy_stars"], dtype=tf.float32
-            )
-
+        self.dataset["positions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32)
+        
+        if self.dataset_type == "train":
+            self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32)
         elif self.dataset_type == "test":
             if "stars" in self.dataset:
-                self.dataset["stars"] = tf.convert_to_tensor(
+                self.dataset["stars"] = ensure_tensor(
self.dataset["stars"], dtype=tf.float32 ) else: From 245d412cc2ab4152f1183e0acef54f3465ddd85e Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 25 Jun 2025 22:17:36 +0200 Subject: [PATCH 075/146] fix: pass random seed to TFNonParametricPolynomialVariationsOPD constructor in psf_model_physical_polychromatic.py --- src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 08c5ce71..9979bcf1 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -282,6 +282,7 @@ def tf_np_poly_opd(self): self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( x_lims=self.model_params.x_lims, y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, d_max=self.model_params.nonparam_hparams.d_max_nonparam, opd_dim=self.opd_dim, ) From b91c00641a5407778d064b1f6f3f9bfe5372ba66 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 26 Jun 2025 14:09:58 +0200 Subject: [PATCH 076/146] Refactor to suppress TensorFlow debug msgs: replace lambda in call method with a proper function: find_position_indices enable batch processing for better graph optimization --- src/wf_psf/psf_models/tf_modules/tf_layers.py | 13 ++--- src/wf_psf/psf_models/tf_modules/tf_utils.py | 55 +++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index 9d6f77c9..fdea0077 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -1,13 +1,13 @@ import tensorflow as tf import tensorflow_addons as tfa from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils import logging logger = logging.getLogger(__name__) - class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. @@ -925,6 +925,7 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] + def call(self, positions): """Calculate the prior Zernike coefficients for a batch of positions. 
@@ -960,12 +961,10 @@ def call(self, positions): """ - def calc_index(idx_pos): - return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] + # Find indices for all positions in one batch operation + idx = find_position_indices(self.obs_pos, positions) - # Calculate the indices of the input batch - indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) - # Recover the prior zernikes from the batch indexes - batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) + # Gather the corresponding Zernike coefficients + batch_zks = tf.gather(self.zks_prior, idx, axis=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index a4795f89..d0a2002c 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -18,6 +18,61 @@ import tensorflow as tf import numpy as np + +@tf.function +def find_position_indices(obs_pos, batch_positions): + """Find indices of batch positions within observed positions using vectorized operations. + + This function locates the indices of multiple query positions within a + reference set of observed positions using broadcasting and vectorized operations. + Each position in the batch must have an exact match in the observed positions. + + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos" + ) + + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. 
From 89dddc5ac468b8dfd248c0d6a9b249341fa37e73 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 26 Jun 2025 17:24:28 +0200 Subject: [PATCH 077/146] Match old behaviour with conditional and float64 accumulation --- src/wf_psf/data/data_zernike_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 309732d8..6ef2c93c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -131,12 +131,17 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray if not contributions: raise ValueError("No contributions provided.") + if len(contributions) == 1: + return contributions[0] + max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): raise ValueError("All contributions must have the same number of samples.") - combined = np.zeros((n_samples, max_order), dtype=np.float32) + combined = np.zeros((n_samples, max_order)) + for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded From 7c38895d9ffb00f529af4c35af73981a4e64252d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 8 Jul 2025 18:26:02 +0200 Subject: [PATCH 078/146] Add helper to stack x/y field coordinates into (N, 2) positions array - Introduced get_positions() to convert x_field and y_field into a stacked (N, 2) array. - Updated data handler to pass positions and sed_data explicitly. - Includes validation for shape mismatches and None inputs. --- src/wf_psf/inference/psf_inference.py | 36 +++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5be49343..f3e876c3 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -160,7 +160,8 @@ def data_handler(self): simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, - dataset=None, + dataset={"positions": self.get_positions()}, + sed_data = self.seds, ) self._data_handler.run_type = "inference" return self._data_handler @@ -171,6 +172,37 @@ def trained_psf_model(self): self._trained_psf_model = self.load_inference_model() return self._trained_psf_model + def get_positions(self): + """ + Combine x_field and y_field into position pairs. + + Returns + ------- + numpy.ndarray + Array of shape (num_positions, 2) where each row contains [x, y] coordinates. + Returns None if either x_field or y_field is None. + + Raises + ------ + ValueError + If x_field and y_field have different lengths. + """ + if self.x_field is None or self.y_field is None: + return None + + x_arr = np.asarray(self.x_field) + y_arr = np.asarray(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError(f"x_field and y_field must have the same length. 
" + f"Got {x_arr.size} and {y_arr.size}") + + # Flatten arrays to handle any input shape, then stack + x_flat = x_arr.flatten() + y_flat = y_arr.flatten() + + return np.column_stack((x_flat, y_flat)) + def load_inference_model(self): """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path @@ -187,7 +219,7 @@ def load_inference_model(self): # Load the trained PSF model return load_trained_psf_model( self.training_config, - self.data_config, + self.data_handler, weights_path_pattern, ) From 19c98a0287fade23e9d894505bd57bf3a3c10f73 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 10 Jul 2025 10:35:39 +0200 Subject: [PATCH 079/146] Add helper method to prepare dataset for inference & handle empty/None fields --- src/wf_psf/inference/psf_inference.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index f3e876c3..6bba2282 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -150,9 +150,17 @@ def simPSF(self): self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF + + def _prepare_dataset_for_inference(self): + """Prepare dataset dictionary for inference, returning None if positions are invalid.""" + positions = self.get_positions() + if positions is None: + return None + return {"positions": positions} + @property def data_handler(self): - if self._data_handler is None: + if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( dataset_type="inference", @@ -160,7 +168,7 @@ def data_handler(self): simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, - dataset={"positions": self.get_positions()}, + dataset=self._prepare_dataset_for_inference(), sed_data = self.seds, ) self._data_handler.run_type = "inference" @@ -193,6 +201,9 @@ def get_positions(self): x_arr = np.asarray(self.x_field) y_arr = np.asarray(self.y_field) + if x_arr.size == 0 or y_arr.size == 0: + return None + if x_arr.size != y_arr.size: raise ValueError(f"x_field and y_field must have the same length. 
" f"Got {x_arr.size} and {y_arr.size}") From 87dc863a315b03f9d5f027346566e823be694faf Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 22 Jul 2025 10:09:13 +0200 Subject: [PATCH 080/146] Update data_handler_test replacing "get_obs_positions" (deprecation) with "get_np_obs_positions" --- src/wf_psf/tests/test_data/data_handler_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 02214959..191f9f81 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -3,7 +3,7 @@ import tensorflow as tf from wf_psf.data.data_handler import ( DataHandler, - get_obs_positions, + get_np_obs_positions, extract_star_data, ) from wf_psf.utils.read_config import RecursiveNamespace @@ -152,8 +152,8 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") -def test_get_obs_positions(mock_data): - observed_positions = get_obs_positions(mock_data) +def test_get_np_obs_positions(mock_data): + observed_positions = get_np_obs_positions(mock_data) expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) From f8bc8bd326f5f69e9219f5905a46198fd4c6c27f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 6 Aug 2025 17:58:48 +0200 Subject: [PATCH 081/146] Remove deprecated code from rebase --- src/wf_psf/tests/test_data/data_handler_test.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 191f9f81..5a76c5af 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -136,21 +136,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.load_dataset() data_handler.validate_and_process_dataset() - data_handler = DataHandler( - dataset_type="train", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=10, - load_data=False - ) - - data_handler.load_dataset() - data_handler.process_sed_data(mock_dataset["SEDs"]) - - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: - data_handler._validate_dataset_structure() - mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") - def test_get_np_obs_positions(mock_data): observed_positions = get_np_obs_positions(mock_data) From 0a748a2531efb97c86bec6ace77108363ceb00bc Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 6 Aug 2025 17:59:32 +0200 Subject: [PATCH 082/146] Remove duplicated checks on arg existance --- src/wf_psf/data/data_handler.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 61796201..09bb3914 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -185,14 +185,7 @@ def _convert_dataset_to_tensorflow(self): if self.dataset_type == "train": self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32) elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = ensure_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - elif "inference" == self.dataset_type: - pass + self.dataset["stars"] = 
ensure_tensor(self.dataset["stars"], dtype=tf.float32) def process_sed_data(self, sed_data): """ From 105bf1a6f1e60add6cd2fbf9ad79241a6372be74 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:16:57 +0200 Subject: [PATCH 083/146] Improve Zernike prior handling in assemble_zernike_contributions - Support both NumPy arrays and TensorFlow tensors as valid inputs for zernike_prior. - Added an eager execution check before calling `.numpy()` to ensure safe conversion of tensors. - Raise informative errors for unsupported types or if eager mode is disabled. - Updated function docstring to reflect accepted types and behavior. - Removed extraneous whitespace and added clarifying comments in related code paths. --- src/wf_psf/data/data_zernike_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 6ef2c93c..8dd74e4d 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -62,7 +62,6 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) ], axis=0, ) - if model_params.use_prior: if prior is not None: logger.warning( @@ -141,7 +140,7 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order)) - + # Pad each contribution to the max order and sum them for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded @@ -164,7 +163,8 @@ def assemble_zernike_contributions( model_params : RecursiveNamespace Parameters controlling which contributions to apply. zernike_prior : Optional[np.ndarray or tf.Tensor] - The precomputed Zernike prior (e.g., from PDC or another model). + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. centroid_dataset : Optional[object] Dataset used to compute centroid correction. Must have both training and test sets. positions : Optional[np.ndarray or tf.Tensor] @@ -184,8 +184,17 @@ def assemble_zernike_contributions( if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") if isinstance(zernike_prior, tf.Tensor): - zernike_prior = zernike_prior.numpy() + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + elif isinstance(zernike_prior, np.ndarray): zernike_contribution_list.append(zernike_prior) + else: + raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -220,7 +229,6 @@ def assemble_zernike_contributions( return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. 
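Taken together with the one-line fix in the next patch (which moves the append
outside the eager-conversion branch), the prior handling accepts either array
type. A minimal sketch under those assumptions; the SimpleNamespace stand-in
and the toy shapes below are illustrative, not repository fixtures:

    import numpy as np
    import tensorflow as tf
    from types import SimpleNamespace

    from wf_psf.data.data_zernike_utils import assemble_zernike_contributions

    # Toy stand-in for model_params; only the flags read during
    # contribution assembly are set here.
    params = SimpleNamespace(
        use_prior=True, correct_centroids=False, add_ccd_misalignments=False
    )

    # In eager mode (the TF2 default) a Tensor prior is converted via
    # .numpy() before being combined; a NumPy prior is used as-is.
    prior = tf.ones((4, 6), dtype=tf.float32)
    result = assemble_zernike_contributions(model_params=params, zernike_prior=prior)
    np.testing.assert_array_equal(result.numpy(), np.ones((4, 6), dtype=np.float32))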
From 3ee97b432c59309728c39a5875f155ee8da6d365 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:25:35 +0200 Subject: [PATCH 084/146] Fix bug where Tensor zernike_prior was not appended after eager conversion --- src/wf_psf/data/data_zernike_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 8dd74e4d..1157be84 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -191,10 +191,10 @@ def assemble_zernike_contributions( "Zernike prior is a TensorFlow tensor but eager execution is disabled. " "Cannot call `.numpy()` outside of eager mode." ) - elif isinstance(zernike_prior, np.ndarray): - zernike_contribution_list.append(zernike_prior) - else: + + elif not isinstance(zernike_prior, np.ndarray): raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") From b94c703d5c6855cdaa2bb68e5c71b88d16c8129e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:28:46 +0200 Subject: [PATCH 085/146] Update unit tests with latest changes to fixtures and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 90994dde..cc6e22de 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -12,10 +12,9 @@ assemble_zernike_contributions, compute_zernike_tip_tilt, ) -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace - @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -29,41 +28,45 @@ def mock_model_params(): def dummy_prior(): return np.ones((4, 6), dtype=np.float32) -@pytest.fixture -def dummy_positions(): - return np.random.rand(4, 2).astype(np.float32) @pytest.fixture def dummy_centroid_dataset(): return {"training": "dummy_train", "test": "dummy_test"} -def test_training_without_prior(mock_model_params): + +def test_training_without_prior(mock_model_params, mock_data): mock_model_params.use_prior = False - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((3, 2))} - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) - assert zinputs.centroid_dataset is data + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - np.testing.assert_array_equal( - zinputs.misalignment_positions, - np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) - ) -@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") -def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + expected_positions = np.concatenate([ + mock_data.training_data.dataset["positions"], + 
mock_data.test_data.dataset["positions"] + ]) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((2, 2))} - mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) - assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] - mock_get_prior.assert_called_once_with(data) def test_training_with_explicit_prior(mock_model_params, caplog): mock_model_params.use_prior = True @@ -224,17 +227,18 @@ def test_zero_order_contributions(): @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) result = assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=dummy_positions + positions = dummy_positions ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) @@ -275,7 +279,7 @@ def test_prior_as_tensor(mock_model_params): model_params=mock_model_params, zernike_prior=tensor_prior ) - + assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) @@ -294,7 +298,7 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" + """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -332,11 +336,11 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 + np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" + """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" # Mock the CentroidEstimator class mock_centroid_calc = 
mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -377,7 +381,6 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Process the displacements and expected values for each image in the batch expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters # Compare expected values with the actual arguments passed to the mock function From 2576208ce4e64e4f950d25a698b81634b2bf4dc3 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:30:00 +0200 Subject: [PATCH 086/146] Set mock Zernike priors to None in test_data_utils.py helper module --- src/wf_psf/tests/test_data/test_data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index a5ead298..1ebc00cb 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -8,8 +8,8 @@ def __init__( self, training_positions, test_positions, - training_zernike_priors, - test_zernike_priors, + training_zernike_priors=None, + test_zernike_priors=None, noisy_stars=None, noisy_masks=None, stars=None, From d0309d10d3e70b66a97767c0751c94df09ae193d Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:31:12 +0200 Subject: [PATCH 087/146] Remove -1.0 multiplicative factor applied to Zernike tip and tilt values --- src/wf_psf/data/centroids.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index a5992608..2c5e5da7 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -32,12 +32,11 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_size : int, optional The batch size to use when processing the stars. Default is 16. - Returns ------- zernike_centroid_array : np.ndarray A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. 
""" star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") @@ -70,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + zk1_2_batch = compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) From 90ba68e2bca59b32c379bcdc5d099af132db6ccf Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:31:47 +0200 Subject: [PATCH 088/146] Update unit tests with changes to compute_centroid_correction --- src/wf_psf/tests/test_data/centroids_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 08c05382..00714733 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -125,8 +125,8 @@ def test_compute_centroid_correction_with_masks(mock_data): # Ensure the result has the correct shape assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - assert np.allclose(result[0, :], np.array([0, -0.1, -0.2])) # First star Zernike coefficients - assert np.allclose(result[1, :], np.array([0, -0.3, -0.4])) # Second star Zernike coefficients + assert np.allclose(result[0, :], np.array([0, 0.1, 0.2])) # First star Zernike coefficients + assert np.allclose(result[1, :], np.array([0, 0.3, 0.4])) # Second star Zernike coefficients def test_compute_centroid_correction_without_masks(mock_data): @@ -162,10 +162,10 @@ def test_compute_centroid_correction_without_masks(mock_data): # Validate expected values (adjust based on behavior) expected_result = np.array([ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4] + [0, 0.1, 0.2], # From training data + [0, 0.3, 0.4], + [0, 0.1, 0.2], # From test data (reused mocked return) + [0, 0.3, 0.4] ]) assert np.allclose(result, expected_result) From 3bd18d2d4ced470db8abd09a6b86276ffa0618a6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:42:28 +0200 Subject: [PATCH 089/146] Move TFPhysicalPolychromaticField.pad_zernikes to helper method pad_tf_zernikes in data_zernike_utils.py - Import pad_tf_zernikes into psf_model_physical_polychromatic.py - Replace calls to pad_zernikes with pad_tf_zernikes - Move padding zernike unit tests to data_zernike_utils_test.py - Update/Remove unit tests in psf_model_physical_polychromatic_test.py --- src/wf_psf/data/data_zernike_utils.py | 40 +++++ .../psf_model_physical_polychromatic.py | 14 +- .../test_data/data_zernike_utils_test.py | 75 ++++++++- .../psf_model_physical_polychromatic_test.py | 144 +++++------------- 4 files changed, 164 insertions(+), 109 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 1157be84..3a7fa90b 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -147,6 +147,46 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray return combined + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. 
+ + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. + """ + + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + def assemble_zernike_contributions( model_params, zernike_prior=None, diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 9979bcf1..918ffc4b 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -11,7 +11,11 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_np_obs_positions -from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes +) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, @@ -575,8 +579,8 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) @@ -613,8 +617,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index cc6e22de..afafc1db 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - ZernikeInputs, ZernikeInputsFactory, get_np_zernike_prior, pad_contribution_to_order, combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, + pad_tf_zernikes ) from 
wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -297,6 +297,79 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm ) +def test_pad_zernikes_num_of_zernikes_equal(): + # Prepare your test tensors + zk_param = tf.constant([[[[1.0]]], [[[2.0]]]]) # Shape (2, 1, 1, 1) + zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]]) # Same shape + + # Reshape to (1, 2, 1, 1) + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) + + # Reset _n_zks_total to max number of zernikes (2 here) + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call pad_zernikes method + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert shapes are equal and correct + assert padded_zk_param.shape[1] == n_zks_total + assert padded_zk_prior.shape[1] == n_zks_total + + # If num zernikes already equal, output should be unchanged + np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) + np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + +def test_pad_zernikes_prior_greater_than_param(): + zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 5, 1, 1) + assert padded_zk_prior.shape == (1, 5, 1, 1) + + +def test_pad_zernikes_param_greater_than_prior(): + zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) + zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 4, 1, 1) + assert padded_zk_prior.shape == (1, 4, 1, 1) + + def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 13d78667..95ad6287 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,7 @@ import pytest import numpy as np import tensorflow as tf -from unittest.mock import PropertyMock +from unittest.mock import patch from 
wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -29,15 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return mock_instance @@ -49,122 +61,48 @@ def mock_model_params(mocker): return model_params_mock @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 
1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) - assert padded_zk_prior.shape == (1, 4, 1, 1) - +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField + instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + return instance def test_compute_zernikes(mocker, physical_layer_instance): - # Expected output of mock components + # Expected output of mock components padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) - expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - - # Patch tf_poly_Z_field property - mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], + dtype=tf.float32 +) + # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_poly_Z_field', - new_callable=PropertyMock, - return_value=mock_tf_poly_Z_field + "tf_poly_Z_field", + return_value=padded_zernike_param ) - # Patch tf_physical_layer property + # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_physical_layer', - new_callable=PropertyMock, - return_value=mock_tf_physical_layer + "tf_physical_layer", + mock_tf_physical_layer ) - # Patch pad_zernikes instance method directly (this one isn't a property) - mocker.patch.object( - physical_layer_instance, - 'pad_zernikes', + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior) ) - # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) # Assertions tf.debugging.assert_equal(zernike_coeffs, expected_values) - assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file + assert zernike_coeffs.shape == expected_values.shape From 7ca574a3096f67d79089b361f2e24552170f39f9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:59:07 +0200 Subject: [PATCH 090/146] Correct bug in test_load_inference_model - Add missing fixture arguments to mock_training_config and mock_inference_config - Add patch for DataHandler - Remove unused import --- .../test_inference/psf_inference_test.py | 38 +++++++++++++++---- 1 file changed, 30 
insertions(+), 8 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 4add460b..91328300 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,11 +12,11 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, - PSFInferenceEngine + PSFInferenceEngine ) from wf_psf.utils.read_config import RecursiveNamespace @@ -29,7 +29,26 @@ def mock_training_config(): model_params=RecursiveNamespace( model_name="mock_model", output_Q=2, - output_dim=32 + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + + ) ) ) ) @@ -48,9 +67,10 @@ def mock_inference_config(): data_config_path=None ), model_params=RecursiveNamespace( + n_bins_lda=8, output_Q=1, output_dim=64 - ) + ), ) ) return inference_config @@ -179,9 +199,11 @@ def test_batch_size_positive(): assert inference.batch_size == 4 +@patch('wf_psf.inference.psf_inference.DataHandler') @patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config): - +def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) mock_config_handler.trained_model_path = "mock/path/to/model" mock_config_handler.training_config = mock_training_config @@ -202,8 +224,8 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_config_handler.training_config, - mock_config_handler.data_config, + mock_training_config, + mock_data_config, weights_path_pattern ) From 47223a9c698bfdb8f2159243b6bc78bce2ac02fd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 18 Aug 2025 15:33:26 +0200 Subject: [PATCH 091/146] Revert sign change applied to compute_centroid_correction --- src/wf_psf/data/centroids.py | 2 +- src/wf_psf/tests/test_data/centroids_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 2c5e5da7..0a3362ef 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -69,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = compute_zernike_tip_tilt( + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 00714733..9a6c6acc 100644 --- 
a/src/wf_psf/tests/test_data/centroids_test.py
+++ b/src/wf_psf/tests/test_data/centroids_test.py
@@ -125,8 +125,8 @@ def test_compute_centroid_correction_with_masks(mock_data):
 
     # Ensure the result has the correct shape
     assert result.shape == (4, 3)  # Should be (n_stars, 3 Zernike components)
-    assert np.allclose(result[0, :], np.array([0, 0.1, 0.2]))  # First star Zernike coefficients
-    assert np.allclose(result[1, :], np.array([0, 0.3, 0.4]))  # Second star Zernike coefficients
+    assert np.allclose(result[0, :], np.array([0, -0.1, -0.2]))  # First star Zernike coefficients
+    assert np.allclose(result[1, :], np.array([0, -0.3, -0.4]))  # Second star Zernike coefficients
 
 
 def test_compute_centroid_correction_without_masks(mock_data):
@@ -161,7 +161,7 @@ def test_compute_centroid_correction_without_masks(mock_data):
     assert result.shape == (4, 3)  # (n_stars, 3 Zernike components)
 
     # Validate expected values (adjust based on behavior)
-    expected_result = np.array([
+    expected_result = -1.0 * np.array([
         [0, 0.1, 0.2],  # From training data
         [0, 0.3, 0.4],
         [0, 0.1, 0.2],  # From test data (reused mocked return)

From e10edc7cb24ec073b8d614dd5c26b90321905e96 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 19 Aug 2025 23:37:52 +0200
Subject: [PATCH 092/146] Refactor _prepare_positions_and_seds to enforce shape
 consistency and add validation

- Ensure x_field and y_field are at least 1D and broadcast to positions of
  shape (n_samples, 2)
- Validate that SEDs batch size matches the number of positions; raise
  ValueError if not
- Validate that SEDs last dimension is 2 (flux, wavelength)
- Process SEDs via data_handler as before
- Remove the need to broadcast SEDs silently, avoiding hidden shape mismatches
- Support single-star, multi-star, and scalar inputs
- Improve error messages for easier debugging in unit tests

Also update unit tests to cover:
- Single-star input shapes
- Mismatched x/y fields or SED batch size (ValueError cases)
---
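Note: a sketch of the shape contract this change enforces, for a hypothetical
single-star call (values are illustrative; this note and sketch are editorial
and not part of the commit):

    import numpy as np
    from wf_psf.utils.utils import ensure_batch

    x_field, y_field = 0.1, 0.2    # scalars are accepted for a single star
    sed = np.random.rand(10, 2)    # (n_bins, 2): flux and wavelength per bin

    # Mirrors _prepare_positions_and_seds: lift scalars to 1D, stack to (n, 2)
    positions = np.column_stack((np.atleast_1d(x_field), np.atleast_1d(y_field)))
    sed_batch = ensure_batch(sed)  # (10, 2) -> (1, 10, 2)

    assert positions.shape == (1, 2)
    assert sed_batch.shape[0] == positions.shape[0]  # mismatch raises ValueError

 src/wf_psf/inference/psf_inference.py         |  60 ++++++--
 .../test_inference/psf_inference_test.py      | 137 +++++++++++++++++-
 src/wf_psf/utils/utils.py                     |  21 +++
 3 files changed, 206 insertions(+), 12 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 6bba2282..76619f5a 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -13,6 +13,7 @@
 import numpy as np
 from wf_psf.data.data_handler import DataHandler
 from wf_psf.utils.read_config import read_conf
+from wf_psf.utils.utils import ensure_batch
 from wf_psf.psf_models import psf_models
 from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
 import tensorflow as tf
@@ -260,13 +261,40 @@ def output_dim(self):
         return self._output_dim
 
     def _prepare_positions_and_seds(self):
-        """Preprocess and return tensors for positions and SEDs."""
-        positions = tf.convert_to_tensor(
-            np.array([self.x_field, self.y_field]).T, dtype=tf.float32
-        )
-        self.data_handler.process_sed_data(self.seds)
-        sed_data = self.data_handler.sed_data
-        return positions, sed_data
+        """
+        Preprocess and return tensors for positions and SEDs with consistent shapes.
+
+        Handles single-star, multi-star, and even scalar inputs, ensuring:
+        - positions: shape (n_samples, 2)
+        - sed_data: shape (n_samples, n_bins, 2)
+        """
+        # Ensure x_field and y_field are at least 1D
+        x_arr = np.atleast_1d(self.x_field)
+        y_arr = np.atleast_1d(self.y_field)
+
+        if x_arr.size != y_arr.size:
+            raise ValueError(f"x_field and y_field must have the same length. 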
" + f"Got {x_arr.size} and {y_arr.size}") + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + + if sed_data.shape[2] != 2: + raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + def run_inference(self): """Run PSF inference and return the full PSF array.""" @@ -291,9 +319,23 @@ def get_psfs(self): self._ensure_psf_inference_completed() return self.engine.get_psfs() - def get_psf(self, index): + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + If only a single star was passed during instantiation, the index defaults to 0. + """ self._ensure_psf_inference_completed() - return self.engine.get_psf(index) + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + class PSFInferenceEngine: def __init__(self, trained_model, batch_size: int, output_dim: int): diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 91328300..ec9f0495 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,7 +12,7 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, @@ -21,6 +21,19 @@ from wf_psf.utils.read_config import RecursiveNamespace +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -83,14 +96,46 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins) + seds=np.random.rand(num_sources, num_bins, 2) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, 
+ "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -308,3 +353,89 @@ def fake_compute_psfs(positions, seds): assert np.all(psfs_2 == expected_psfs) assert mock_compute_psfs.call_count == 1 + + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (1, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (1, num_bins, 2)" + + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (2, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while 
sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() \ No newline at end of file diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index e27ab33d..f13d2c59 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -19,6 +19,27 @@ def scale_to_range(input_array, old_range, new_range): input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] return input_array +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. + + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. + """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + def calc_wfe(zernike_basis, zks): wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) From 339bfab545fd010d74cb051f36f29da04ef406cb Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 21 Aug 2025 16:43:44 +0200 Subject: [PATCH 093/146] Fix tensor handling in ZernikeInputsFactory Explicitly convert positions to NumPy arrays with `.numpy()` and adjust inference path to read from `data.dataset` for positions and priors. --- src/wf_psf/data/data_zernike_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 3a7fa90b..4149d79e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_data.dataset["positions"], - data.test_data.dataset["positions"] + data.training_data.dataset["positions"].numpy(), + data.test_data.dataset["positions"].numpy() ], axis=0, ) @@ -72,11 +72,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) elif run_type == "inference": centroid_dataset = None - positions = data["positions"] + positions = data.dataset["positions"].numpy() if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") if prior is None: logger.warning( From adb88ea7b5e964684b628ec7ea6eef6372e07f18 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 21 Aug 2025 11:50:45 +0200 Subject: [PATCH 094/146] Reformat and remove unused import --- src/wf_psf/data/data_handler.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 09bb3914..fe660940 100644 --- a/src/wf_psf/data/data_handler.py +++ 
b/src/wf_psf/data/data_handler.py @@ -17,7 +17,6 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf -from fractions import Fraction from typing import Optional, Union import logging @@ -89,7 +88,7 @@ def __init__( and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded from disk using `data_params`, and SEDs are extracted and processed automatically. - + Parameters ---------- dataset_type : str @@ -168,10 +167,14 @@ def _validate_dataset_structure(self): if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) elif self.dataset_type == "test": if "stars" not in self.dataset: - raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." + ) elif self.dataset_type == "inference": pass else: @@ -180,12 +183,18 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["positions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32) - + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if self.dataset_type == "train": - self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32) + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) elif self.dataset_type == "test": - self.dataset["stars"] = ensure_tensor(self.dataset["stars"], dtype=tf.float32) + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) def process_sed_data(self, sed_data): """ From b147be4eed4c2a17aaf0e050cdbf53e485343e45 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:39:40 +0200 Subject: [PATCH 095/146] Correct zernike_prior extraction when dataset is a dict, reformat file --- src/wf_psf/data/data_zernike_utils.py | 76 ++++++++++++++++----------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 4149d79e..f7653540 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -26,13 +26,17 @@ @dataclass class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) - centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections class ZernikeInputsFactory: @staticmethod - def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: """Builds a ZernikeInputs dataclass instance based on run type and data. 
Parameters @@ -58,7 +62,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) positions = np.concatenate( [ data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy() + data.test_data.dataset["positions"].numpy(), ], axis=0, ) @@ -76,7 +80,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) if prior is None: logger.warning( @@ -89,7 +97,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions + misalignment_positions=positions, ) @@ -119,12 +127,14 @@ def get_np_zernike_prior(data): return zernike_prior + def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: """Pad a Zernike contribution array to the max Zernike order.""" current_order = contribution.shape[1] pad_width = ((0, 0), (0, max_order - current_order)) return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: """Combine multiple Zernike contributions, padding each to the max order before summing.""" if not contributions: @@ -228,12 +238,14 @@ def assemble_zernike_contributions( zernike_prior = zernike_prior.numpy() else: raise RuntimeError( - "Zernike prior is a TensorFlow tensor but eager execution is disabled. " - "Cannot call `.numpy()` outside of eager mode." + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." ) - + elif not isinstance(zernike_prior, np.ndarray): - raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -254,7 +266,9 @@ def assemble_zernike_contributions( ccd_misalignment = compute_ccd_misalignment(model_params, positions) zernike_contribution_list.append(ccd_misalignment) else: - logger.info("Skipping CCD misalignment correction (not enabled or no positions).") + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) # If no contributions, return zeros tensor to avoid crashes if not zernike_contribution_list: @@ -267,7 +281,7 @@ def assemble_zernike_contributions( combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. 
@@ -303,18 +317,19 @@ def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): * 3.0 ) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, pixel_sampling: float = 12e-6, - reference_shifts: list[float] = [-1/3, -1/3], + reference_shifts: list[float] = [-1 / 3, -1 / 3], sigma_init: float = 2.5, n_iter: int = 20, ) -> np.ndarray: """ Compute Zernike tip-tilt corrections for a batch of PSF images. - This function estimates the centroid shifts of multiple PSFs and computes + This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference. Parameters @@ -330,7 +345,7 @@ def compute_zernike_tip_tilt( pixel_sampling : float, optional The pixel size in meters. Defaults to `12e-6 m` (12 microns). reference_shifts : list[float], optional - The target centroid shifts in pixels, specified as `[dy, dx]`. + The target centroid shifts in pixels, specified as `[dy, dx]`. Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). sigma_init : float, optional Initial standard deviation for centroid estimation. Default is `2.5`. @@ -343,21 +358,18 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - + Notes ----- - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. """ from wf_psf.data.centroids import CentroidEstimator - + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( - im=star_images, - mask=star_masks, - sigma_init=sigma_init, - n_iter=n_iter - ) + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) shifts = centroid_estimator.get_intra_pixel_shifts() @@ -365,23 +377,25 @@ def compute_zernike_tip_tilt( reference_shifts = np.array(reference_shifts) # Reshape to ensure it's a column vector (1, 2) - reference_shifts = reference_shifts[None,:] - + reference_shifts = reference_shifts[None, :] + # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) - + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + # Compute displacements - displacements = (reference_shifts - shifts) # - + displacements = reference_shifts - shifts # + # Ensure the correct axis order for displacements (x-axis, then y-axis) - displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements - zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call - + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + # Reshape the result back to the original shape of displacements zk1_2_array = zk1_2_array.reshape(displacements.shape) - + return zk1_2_array From 8747300d500b48720894d87372efdf0eff0308f6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:41:46 +0200 Subject: [PATCH 096/146] Replace np.array with Tensorflow tensors in unit test and fixtures, reformat files --- src/wf_psf/tests/test_data/conftest.py | 50 ++-- .../test_data/data_zernike_utils_test.py | 242 +++++++++++------- 2 files changed, 176 insertions(+), 116 deletions(-) diff --git 
a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 04a56893..47eed929 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -100,38 +100,35 @@ def mock_data(scope="module"): """Fixture to provide mock data for testing.""" # Mock positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) + training_positions = tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) # Define dummy 5x5 image patches for stars (mock star images) # Define varied values for 5x5 star images - noisy_stars = tf.constant([ - np.arange(25).reshape(5, 5), - np.arange(25, 50).reshape(5, 5) - ], dtype=tf.float32) - - noisy_masks = tf.constant([ - np.eye(5), - np.ones((5, 5)) - ], dtype=tf.float32) - - stars = tf.constant([ - np.full((5, 5), 100), - np.full((5, 5), 200) - ], dtype=tf.float32) - - masks = tf.constant([ - np.zeros((5, 5)), - np.tri(5) - ], dtype=tf.float32) + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) return MockData( - training_positions, test_positions, training_zernike_priors, - test_zernike_priors, noisy_stars, noisy_masks, stars, masks + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, ) + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" @@ -140,11 +137,13 @@ def simple_image(scope="module"): image[:, 2, 2] = 1 # Place the star at the center for each image return image + @pytest.fixture def identity_mask(scope="module"): """Creates a mask where all pixels are fully considered.""" return np.ones((5, 5)) + @pytest.fixture def multiple_images(scope="module"): """Fixture for a batch of images with stars at different positions.""" @@ -154,6 +153,7 @@ def multiple_images(scope="module"): images[2, 3, 1] = 1 # Star at (3, 1) in image 2 return images + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index afafc1db..390d10f9 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,4 +1,3 @@ - import pytest import numpy as np from unittest.mock import MagicMock, patch @@ -10,11 +9,12 @@ combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace + @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -24,6 +24,7 @@ def mock_model_params(): param_hparams=RecursiveNamespace(n_zernikes=6), ) + @pytest.fixture def dummy_prior(): return np.ones((4, 6), dtype=np.float32) @@ -41,22 +42,28 @@ def test_training_without_prior(mock_model_params, mock_data): 
mock_data.training_data.dataset.pop("zernike_prior", None) mock_data.test_data.dataset.pop("zernike_prior", None) - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - expected_positions = np.concatenate([ - mock_data.training_data.dataset["positions"], - mock_data.test_data.dataset["positions"] - ]) + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) expected_priors = np.concatenate( ( @@ -77,7 +84,9 @@ def test_training_with_explicit_prior(mock_model_params, caplog): explicit_prior = np.array([9.0, 9.0, 9.0]) with caplog.at_level("WARNING"): - zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) assert "Zernike prior explicitly provided" in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -85,16 +94,24 @@ def test_training_with_explicit_prior(mock_model_params, caplog): def test_inference_with_dict_and_prior(mock_model_params): mock_model_params.use_prior = True - data = { - "positions": np.ones((5, 2)), - "zernike_prior": np.array([42.0, 0.0]) - } + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) assert zinputs.centroid_dataset is None - assert (zinputs.zernike_prior == data["zernike_prior"]).all() - np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) def test_invalid_run_type(mock_model_params): @@ -111,7 +128,7 @@ def test_get_np_zernike_prior(): # Construct fake DataConfigHandler structure using RecursiveNamespace data = RecursiveNamespace( training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), - test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), ) expected_prior = np.concatenate((training_prior, test_prior), axis=0) @@ -121,19 +138,24 @@ def test_get_np_zernike_prior(): # Assert shape and values match expected np.testing.assert_array_equal(result, expected_prior) + def test_pad_contribution_to_order(): # Input: batch of 2 samples, each with 3 Zernike coefficients - input_contribution = np.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]) - + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + max_order = 5 # Target size: pad to 5 coefficients - expected_output = np.array([ - [1.0, 2.0, 3.0, 0.0, 0.0], - [4.0, 5.0, 6.0, 
0.0, 0.0], - ]) + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) padded = pad_contribution_to_order(input_contribution, max_order) @@ -149,6 +171,7 @@ def test_no_padding_needed(): assert output.shape == input_contribution.shape np.testing.assert_array_equal(output, input_contribution) + def test_padding_to_much_higher_order(): """Pad from order 2 to order 10.""" input_contribution = np.array([[1, 2], [3, 4]]) @@ -158,6 +181,7 @@ def test_padding_to_much_higher_order(): assert output.shape == (2, 10) np.testing.assert_array_equal(output, expected_output) + def test_empty_contribution(): """Test behavior with empty input array (0 features).""" input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients @@ -180,42 +204,44 @@ def test_zero_samples(): def test_combine_zernike_contributions_basic_case(): """Combine two contributions with matching sample count and varying order.""" - contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - contrib2 = np.array([[5], [6]]) # shape (2, 1) - expected = np.array([ - [1 + 5, 2 + 0], - [3 + 6, 4 + 0] - ]) # padded contrib2 to (2, 2) + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) result = combine_zernike_contributions([contrib1, contrib2]) np.testing.assert_array_equal(result, expected) + def test_combine_multiple_contributions(): """Combine three contributions.""" - c1 = np.array([[1, 2, 3]]) # shape (1, 3) - c2 = np.array([[4, 5]]) # shape (1, 2) - c3 = np.array([[6]]) # shape (1, 1) - expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) result = combine_zernike_contributions([c1, c2, c3]) np.testing.assert_array_equal(result, expected) + def test_empty_input_list(): """Raise ValueError when input list is empty.""" with pytest.raises(ValueError, match="No contributions provided."): combine_zernike_contributions([]) + def test_inconsistent_sample_count(): """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - c2 = np.array([[5, 6]]) # shape (1, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) with pytest.raises(ValueError): combine_zernike_contributions([c1, c2]) + def test_single_contribution(): """Combining a single contribution should return the same array (no-op).""" contrib = np.array([[7, 8, 9], [10, 11, 12]]) result = combine_zernike_contributions([contrib]) np.testing.assert_array_equal(result, contrib) + def test_zero_order_contributions(): """Contributions with 0 Zernike coefficients.""" contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients @@ -225,9 +251,12 @@ def test_zero_order_contributions(): assert result.shape == (2, 0) np.testing.assert_array_equal(result, expected) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) dummy_positions = 
np.full((4, 6), 1.0) @@ -236,12 +265,13 @@ def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_param model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions = dummy_positions + positions=dummy_positions, ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) + def test_prior_only(mock_model_params, dummy_prior): mock_model_params.correct_centroids = False mock_model_params.add_ccd_misalignments = False @@ -250,11 +280,12 @@ def test_prior_only(mock_model_params, dummy_prior): model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=None, - positions=None + positions=None, ) np.testing.assert_array_equal(result.numpy(), dummy_prior) + def test_no_contributions_returns_zeros(): model_params = RecursiveNamespace( use_prior=False, @@ -269,6 +300,7 @@ def test_no_contributions_returns_zeros(): assert result.shape == (1, 8) np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + def test_prior_as_tensor(mock_model_params): tensor_prior = tf.ones((4, 6), dtype=tf.float32) @@ -276,24 +308,28 @@ def test_prior_as_tensor(mock_model_params): mock_model_params.add_ccd_misalignments = False result = assemble_zernike_contributions( - model_params=mock_model_params, - zernike_prior=tensor_prior + model_params=mock_model_params, zernike_prior=tensor_prior ) assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") -def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_inconsistent_shapes_raises_error( + mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_model_params.add_ccd_misalignments = False mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 - with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + with pytest.raises( + ValueError, match="All contributions must have the same number of samples" + ): assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=None + positions=None, ) @@ -307,14 +343,10 @@ def test_pad_zernikes_num_of_zernikes_equal(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reset _n_zks_total to max number of zernikes (2 here) - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call pad_zernikes method - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert shapes are equal and correct assert padded_zk_param.shape[1] == n_zks_total @@ -324,6 +356,7 @@ def test_pad_zernikes_num_of_zernikes_equal(): np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + def test_pad_zernikes_prior_greater_than_param(): zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) @@ -333,14 +366,10 @@ def test_pad_zernikes_prior_greater_than_param(): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # 
Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 5, 1, 1) @@ -356,14 +385,10 @@ def test_pad_zernikes_param_greater_than_prior(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 4, 1, 1) @@ -374,16 +399,20 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02]] + ) # Shape (1, 2) # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) # Define test inputs (batch of 1 image) @@ -391,41 +420,58 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions # Run the function - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) # Expected shifts based on centroid calculation - expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters - expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters + expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters + expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters # Expected calls to the mocked function # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args # Compare expected values with the actual arguments 
passed to the mock function - np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) # Check dy values similarly - np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 - np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 + np.testing.assert_allclose( + zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 + np.testing.assert_allclose( + zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 + def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" - + # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] + ) # Shape (3, 2) # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) # Define test inputs (batch of 3 images) @@ -434,16 +480,18 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Run the function zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts - ) + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts, + ) # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + assert ( + len(mock_shift_fn.call_args_list) == 1 + ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + args, _ = mock_shift_fn.call_args_list[0] print("Shape of args[0]:", args[0].shape) print("Contents of args[0]:", args[0]) @@ -453,13 +501,25 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): args_array = np.array(args[0]).reshape(-1, 2) # Process the displacements and expected values for each image in the batch - expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - 
mock_instance.get_intra_pixel_shifts.return_value[:, 0]
+    )  # Expected y-axis shift in meters
 
     # Compare expected values with the actual arguments passed to the mock function
-    np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0)
-    np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0)
+    np.testing.assert_allclose(
+        args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0
+    )
+    np.testing.assert_allclose(
+        args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0
+    )
 
     # Expected values based on mock side_effect (0.5 * shift)
-    np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5)  # Zk1 for each image
-    np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5)  # Zk2 for each image
\ No newline at end of file
+    np.testing.assert_allclose(
+        zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5
+    )  # Zk1 for each image
+    np.testing.assert_allclose(
+        zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5
+    )  # Zk2 for each image

From 633266adc1f7e3b73e123ac56d621ac3011853cb Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 27 Aug 2025 11:19:53 +0200
Subject: [PATCH 097/146] Eagerly initialise trainable layers in physical poly model constructor required for evaluation/inference

---
 .../models/psf_model_physical_polychromatic.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index 918ffc4b..e2302da5 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -120,15 +120,15 @@ def __init__(self, model_params, training_params, data, coeff_mat=None):
         self.run_type = self._get_run_type(data)
         self.obs_pos = self.get_obs_pos()
 
-        # Initialize the model parameters and layers
+        # Initialize the model parameters
         self.output_Q = model_params.output_Q
         self.l2_param = model_params.param_hparams.l2_param
         self.output_dim = model_params.output_dim
-
+
         # Initialise lazy loading of external Zernike prior
         self._external_prior = None
 
-        # Initialize the model parameters with non-default value
+        # Set Zernike Polynomial Coefficient Matrix if not None
         if coeff_mat is not None:
             self.assign_coeff_matrix(coeff_mat)
 
@@ -158,9 +158,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None):
             rotation_angle=self.model_params.obscuration_rotation_angle,
         )
 
-        # Eagerly initialise tf_batch_poly_PSF
+        # Eagerly initialise model layers
         self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF()
-
+        _ = self.tf_poly_Z_field
+        _ = self.tf_np_poly_opd
 
     def _get_run_type(self, data):
         if hasattr(data, 'run_type'):
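The eager initialisation introduced above matters because TensorFlow checkpoint restoration can only map saved values onto variables that already exist: a layer that is built lazily on its first call has nothing for load_weights to fill. A minimal sketch of that behaviour, using a toy Keras model rather than the WaveDiff classes (all names below are illustrative, not part of the patch series):

    import tensorflow as tf

    class LazyModel(tf.keras.Model):
        """Toy stand-in for a model whose layers build on first call."""

        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(4)  # no variables created yet

        def call(self, x):
            return self.dense(x)

    model = LazyModel()
    print(len(model.trainable_variables))  # 0: a restore would find nothing to fill

    # Building eagerly, as the constructor above now does for its layers,
    # creates the variables so load_weights can map saved values onto them.
    model.build(input_shape=(None, 8))
    print(len(model.trainable_variables))  # 2: kernel and bias now exist

The patch that follows is the complementary fix on the loading side: expect_partial() accepts that optimizer slots saved during training have no counterpart in an inference-only model.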
From c72cb84cea8edca3e76c0397b547a7578f9c7204 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 27 Aug 2025 11:21:59 +0200
Subject: [PATCH 098/146] fix: use expect_partial() when loading model weights for evaluation

- Add status handling to model.load_weights() call
- Use expect_partial() to suppress warnings about unused optimizer state
- Allows successful weight loading for metrics evaluation when checkpoint
  contains training artifacts
---
 src/wf_psf/psf_models/psf_model_loader.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py
index 797be8fc..c30d31ad 100644
--- a/src/wf_psf/psf_models/psf_model_loader.py
+++ b/src/wf_psf/psf_models/psf_model_loader.py
@@ -12,7 +12,7 @@
     get_psf_model,
     get_psf_model_weights_filepath
 )
-
+import tensorflow as tf
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +48,9 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern):
 
     try:
         logger.info(f"Loading PSF model weights from {weights_path}")
-        model.load_weights(weights_path)
+        status = model.load_weights(weights_path)
+        status.expect_partial()
+
     except Exception as e:
         logger.exception("Failed to load model weights.")
         raise RuntimeError("Model weight loading failed.") from e

From fea4e3e25413eafe515b06e9b212bc25381ab335 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 27 Aug 2025 11:24:43 +0200
Subject: [PATCH 099/146] Add memory cleanup after training completion

- Delete model reference and run garbage collection
- Clear TensorFlow session to free GPU memory
- Prevents OOM issues in subsequent operations or multiple training runs
---
 src/wf_psf/training/train.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py
index 3ee47fa8..2da47900 100644
--- a/src/wf_psf/training/train.py
+++ b/src/wf_psf/training/train.py
@@ -7,6 +7,7 @@
 
 """
 
+import gc
 import numpy as np
 import time
 import tensorflow as tf
@@ -525,3 +526,8 @@ def train(
     final_time = time.time()
     logger.info("\nTotal elapsed time: %f" % (final_time - starting_time))
     logger.info("\n Training complete..")
+
+    # Clean up memory
+    del psf_model
+    gc.collect()
+    tf.keras.backend.clear_session()

From 12335904bb22c061c24f8b48a8f6a6a0092169c6 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Thu, 4 Sep 2025 01:59:17 +0200
Subject: [PATCH 100/146] refactor: centralise PSF data extraction in data_handler

- Introduce unified `get_data_array` for training/metrics/inference access
- Add helpers `extract_star_data` and `_get_inference_data`
- Remove redundant `get_np_obs_positions`
- Move data handling logic out of `compute_centroid_correction`
- Standardise `centroid_dataset` as dict (stamps + optional masks)
- Support optional keys (masks, priors) via `allow_missing`
- Improve and unify docstrings for data extraction utilities
- Add optional "sources" and "masks" attributes to `PSFInference`
- Add `correct_centroids` and `add_ccd_misalignments` as options to
  inference_config.yaml
---
 src/wf_psf/data/centroids.py | 35 ++-
 src/wf_psf/data/data_handler.py | 200 ++++++++++++++----
 src/wf_psf/data/data_zernike_utils.py | 27 +--
 src/wf_psf/inference/psf_inference.py | 6 +-
 .../psf_model_physical_polychromatic.py | 7 +-
 5 files changed, 193 insertions(+), 82 deletions(-)

diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py
index 0a3362ef..4b421597 100644
--- a/src/wf_psf/data/centroids.py
+++ b/src/wf_psf/data/centroids.py
@@ -14,7 +14,7 @@
 from typing import Optional
 
 
-def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray:
+def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray:
     """Compute centroid corrections using Zernike polynomials.
 
     This function calculates the Zernike contributions required to match the centroid

     Parameters
     ----------
     model_params : RecursiveNamespace
         An object containing parameters for the PSF model, including pixel sampling
         and initial centroid window parameters.
- data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + batch_size : int, optional The batch size to use when processing the stars. Default is 16. @@ -39,24 +42,14 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. """ - star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index fe660940..bd802763 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -240,67 +240,43 @@ def process_sed_data(self, sed_data): self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. + """ + Extract and concatenate star-related data from training and test datasets. - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. 
+ This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. Parameters ---------- data : DataConfigHandler Object containing training and test datasets. train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). Returns ------- np.ndarray - A NumPy array containing the concatenated data for the given keys. + Concatenated NumPy array containing the selected data from both + training and test sets. Raises ------ KeyError - If the specified keys do not exist in the training or test datasets. + If either the training or test dataset does not contain the + requested key. Notes ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ @@ -327,3 +303,145 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = False, +) -> np.ndarray | None: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. + run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default False + Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. 
+ + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_inference_data(data, effective_key, allow_missing) + except Exception as e: + if allow_missing: + return None + raise + + +def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: + """ + Extract inference data with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. + KeyError + If key is not found in dataset and allow_missing=False. 
+ + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index f7653540..9bf9d1fd 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -16,6 +16,7 @@ import numpy as np import tensorflow as tf from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -54,29 +55,29 @@ def build( ------- ZernikeInputs """ - centroid_dataset = None - positions = None + centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets - positions = np.concatenate( - [ - data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy(), - ], - axis=0, - ) + stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + if model_params.use_prior: if prior is not None: logger.warning( - "Zernike prior explicitly provided; ignoring dataset-based prior despite use_prior=True." + "Explicit prior provided; ignoring dataset-based prior." ) else: prior = get_np_zernike_prior(data) elif run_type == "inference": - centroid_dataset = None - positions = data.dataset["positions"].numpy() + stamps = get_data_array(data=data, run_type=run_type, key="sources") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") if model_params.use_prior: # Try to extract prior from `data`, if present diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 76619f5a..a37e6ba7 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -97,7 +97,7 @@ class PSFInference: Spectral energy distributions (SEDs). 
""" - def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None): + def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): self.inference_config_path = inference_config_path @@ -105,6 +105,8 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self.x_field = x_field self.y_field = y_field self.seds = seds + self.sources = sources + self.masks = masks # Internal caches for lazy-loading self._config_handler = None @@ -157,7 +159,7 @@ def _prepare_dataset_for_inference(self): positions = self.get_positions() if positions is None: return None - return {"positions": positions} + return {"positions": positions, "sources": self.sources, "masks": self.masks} @property def data_handler(self): diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index e2302da5..c8086cb7 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,7 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.data_handler import get_np_obs_positions +from wf_psf.data.data_handler import get_data_array from wf_psf.data.data_zernike_utils import ( ZernikeInputsFactory, assemble_zernike_contributions, @@ -203,10 +203,7 @@ def save_nonparam_history(self) -> bool: def get_obs_pos(self): assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - if self.run_type in {"training", "simulation", "metrics"}: - raw_pos = get_np_obs_positions(self.data) - else: - raw_pos = self.data.dataset["positions"] + raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions") obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) From c3a522ac3f373f38d5db673b6dadcf3ed38e647e Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 02:41:53 +0200 Subject: [PATCH 101/146] Add and options to inference_config.yaml (forgot to stage with previous commit) --- config/inference_conf.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index c9d29cb8..927723c7 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -30,3 +30,8 @@ inference: # Dimension of the pixel PSF postage stamp output_dim: 64 + # Flag to perform centroid error correction + correct_centroids: False + + # Flag to perform CCD misalignment error correction + add_ccd_misalignments: True From 0d1aa6ca56f8282840b21dcdff62087923bf64f2 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 04:35:23 +0200 Subject: [PATCH 102/146] Update PSFInference doc string with new optional attributes --- src/wf_psf/inference/psf_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index a37e6ba7..5c27fdbf 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -95,6 +95,10 @@ class PSFInference: y coordinates in SHE convention. seds : array-like, optional Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. 
""" def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): From 2a44500adaee06e166d1aa073675d704086dfa01 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 05:44:34 +0200 Subject: [PATCH 103/146] Rename _get_inference_data to _get_direct_data --- src/wf_psf/data/data_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bd802763..c0ddf7b0 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -394,16 +394,16 @@ def get_data_array( if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference - return _get_inference_data(data, effective_key, allow_missing) + return _get_direct_data(data, effective_key, allow_missing) except Exception as e: if allow_missing: return None raise -def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: """ - Extract inference data with proper error handling and type conversion. + Extract data directly with proper error handling and type conversion. Parameters ---------- From 1fc9cddd5246a77cc6d00559323fd73d83a9a411 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 5 Sep 2025 10:37:27 -0400 Subject: [PATCH 104/146] Reformat with black --- src/wf_psf/__init__.py | 8 +- src/wf_psf/data/centroids.py | 136 ++++++----- src/wf_psf/data/data_zernike_utils.py | 6 +- src/wf_psf/inference/psf_inference.py | 84 ++++--- src/wf_psf/instrument/ccd_misalignments.py | 32 +-- src/wf_psf/metrics/metrics_interface.py | 69 +++--- .../psf_model_physical_polychromatic.py | 125 +++++----- src/wf_psf/psf_models/psf_model_loader.py | 16 +- src/wf_psf/psf_models/tf_modules/tf_layers.py | 2 +- .../psf_models/tf_modules/tf_modules.py | 91 ++++--- src/wf_psf/psf_models/tf_modules/tf_utils.py | 42 ++-- src/wf_psf/sims/psf_simulator.py | 2 +- .../masked_loss/results/plot_results.ipynb | 164 ++++++++++--- src/wf_psf/tests/test_data/centroids_test.py | 227 +++++++++++------- .../tests/test_data/data_handler_test.py | 73 +++--- src/wf_psf/tests/test_data/test_data_utils.py | 22 +- .../test_inference/psf_inference_test.py | 183 ++++++++------ src/wf_psf/tests/test_metrics/conftest.py | 9 +- .../test_metrics/metrics_interface_test.py | 77 +++--- .../psf_model_physical_polychromatic_test.py | 39 +-- .../tests/test_psf_models/psf_models_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 21 +- src/wf_psf/tests/test_utils/utils_test.py | 32 ++- src/wf_psf/training/train.py | 27 ++- src/wf_psf/utils/configs_handler.py | 58 +++-- src/wf_psf/utils/read_config.py | 4 +- src/wf_psf/utils/utils.py | 1 + 27 files changed, 944 insertions(+), 608 deletions(-) diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py index 863675f1..988b02fe 100644 --- a/src/wf_psf/__init__.py +++ b/src/wf_psf/__init__.py @@ -1,7 +1,7 @@ import importlib # Dynamically import modules to trigger side effects when wf_psf is imported -importlib.import_module('wf_psf.psf_models.psf_models') -importlib.import_module('wf_psf.psf_models.models.psf_model_semiparametric') -importlib.import_module('wf_psf.psf_models.models.psf_model_physical_polychromatic') -importlib.import_module('wf_psf.psf_models.tf_modules.tf_psf_field') +importlib.import_module("wf_psf.psf_models.psf_models") 
+importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric") +importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic") +importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field") diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 4b421597..20391793 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -14,7 +14,9 @@ from typing import Optional -def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray: +def compute_centroid_correction( + model_params, centroid_dataset, batch_size: int = 1 +) -> np.ndarray: """Compute centroid corrections using Zernike polynomials. This function calculates the Zernike contributions required to match the centroid @@ -31,15 +33,15 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= Array of star postage stamps (required). - "masks" : Optional[np.ndarray] Array of star masks (optional, can be None). - + batch_size : int, optional The batch size to use when processing the stars. Default is 16. Returns ------- zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike (Z1, Z2) contributions, + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. """ # Retrieve stamps and masks from centroid_dataset @@ -51,15 +53,17 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] n_stars = len(star_postage_stamps) zernike_centroid_array = [] # Batch process the stars for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i:i + batch_size] - batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch zk1_2_batch = -1.0 * compute_zernike_tip_tilt( @@ -67,23 +71,31 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= ) # Zero pad array for each batch and append - zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) # Combine all batches into a single array return np.concatenate(zernike_centroid_array, axis=0) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, pixel_sampling: float = 12e-6, - reference_shifts: list[float] = [-1/3, -1/3], + reference_shifts: list[float] = [-1 / 3, -1 / 3], sigma_init: float = 2.5, n_iter: int = 20, ) -> np.ndarray: """ Compute Zernike tip-tilt corrections for a batch of PSF images. 
- This function estimates the centroid shifts of multiple PSFs and computes + This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference. Parameters @@ -99,7 +111,7 @@ def compute_zernike_tip_tilt( pixel_sampling : float, optional The pixel size in meters. Defaults to `12e-6 m` (12 microns). reference_shifts : list[float], optional - The target centroid shifts in pixels, specified as `[dy, dx]`. + The target centroid shifts in pixels, specified as `[dy, dx]`. Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). sigma_init : float, optional Initial standard deviation for centroid estimation. Default is `2.5`. @@ -112,20 +124,18 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - + Notes ----- - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. """ from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( - im=star_images, - mask=star_masks, - sigma_init=sigma_init, - n_iter=n_iter - ) + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) shifts = centroid_estimator.get_intra_pixel_shifts() @@ -133,78 +143,79 @@ def compute_zernike_tip_tilt( reference_shifts = np.array(reference_shifts) # Reshape to ensure it's a column vector (1, 2) - reference_shifts = reference_shifts[None,:] - + reference_shifts = reference_shifts[None, :] + # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) - + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + # Compute displacements - displacements = (reference_shifts - shifts) # - + displacements = reference_shifts - shifts # + # Ensure the correct axis order for displacements (x-axis, then y-axis) - displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements - zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call - + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + # Reshape the result back to the original shape of displacements zk1_2_array = zk1_2_array.reshape(displacements.shape) - - return zk1_2_array + return zk1_2_array class CentroidEstimator: """ Calculate centroids and estimate intra-pixel shifts for a batch of star images. - This class estimates the centroid of each star in a batch of images using an - iterative process that fits an elliptical Gaussian model to the star images. - The estimated centroids are returned along with the intra-pixel shifts, which - represent the difference between the estimated centroid and the center of the + This class estimates the centroid of each star in a batch of images using an + iterative process that fits an elliptical Gaussian model to the star images. + The estimated centroids are returned along with the intra-pixel shifts, which + represent the difference between the estimated centroid and the center of the image grid (or pixel grid). 
- The process is vectorized, allowing multiple star images to be processed in + The process is vectorized, allowing multiple star images to be processed in parallel, which significantly improves performance when working with large batches. Parameters ---------- im : numpy.ndarray - A 3D numpy array of star image stamps. The shape of the array should be - (n_images, height, width), where n_images is the number of stars, and + A 3D numpy array of star image stamps. The shape of the array should be + (n_images, height, width), where n_images is the number of stars, and height and width are the dimensions of each star's image. - + mask : numpy.ndarray, optional - A 3D numpy array of the same shape as `im`, representing the mask for each star image. - A mask value of `0` indicates that the pixel is fully considered (unmasked), while a value of `1` means the pixel is completely ignored (masked). - Values between `0` and `1` act as weights, allowing partial consideration of the pixel. - If not provided, no mask is applied. + A 3D numpy array of the same shape as `im`, representing the mask for each star image. + A mask value of `0` indicates that the pixel is fully considered (unmasked), while a value of `1` means the pixel is completely ignored (masked). + Values between `0` and `1` act as weights, allowing partial consideration of the pixel. + If not provided, no mask is applied. sigma_init : float, optional - The initial guess for the standard deviation (sigma) of the elliptical Gaussian + The initial guess for the standard deviation (sigma) of the elliptical Gaussian that models the star. Default is 7.5. n_iter : int, optional - The number of iterations for the iterative centroid estimation procedure. + The number of iterations for the iterative centroid estimation procedure. Default is 5. auto_run : bool, optional - If True, the centroid estimation procedure will be automatically run upon + If True, the centroid estimation procedure will be automatically run upon initialization. Default is True. xc : float, optional - The initial guess for the x-component of the centroid. If None, it is set + The initial guess for the x-component of the centroid. If None, it is set to the center of the image. Default is None. yc : float, optional - The initial guess for the y-component of the centroid. If None, it is set + The initial guess for the y-component of the centroid. If None, it is set to the center of the image. Default is None. Attributes ---------- xc : numpy.ndarray The x-components of the estimated centroids for each image. Shape is (n_images,). - + yc : numpy.ndarray The y-components of the estimated centroids for each image. Shape is (n_images,). @@ -215,10 +226,10 @@ class CentroidEstimator: elliptical_gaussian(e1=0, e2=0) Computes an elliptical 2D Gaussian with the specified shear parameters. - + compute_moments() Computes the first-order moments of the star images and updates the centroid estimates. - + estimate() Runs the iterative centroid estimation procedure for all images. @@ -232,12 +243,14 @@ class CentroidEstimator: ----- The iterative centroid estimation procedure fits an elliptical Gaussian to each star image and computes the centroid by calculating the weighted moments. The - `estimate()` method performs the centroid calculation for a batch of images using - the iterative approach defined by the `n_iter` parameter. 
This class is designed + `estimate()` method performs the centroid calculation for a batch of images using + the iterative approach defined by the `n_iter` parameter. This class is designed to be efficient and scalable when processing large batches of star images. """ - def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None): + def __init__( + self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None + ): """Initialize class attributes.""" # Convert to np.ndarray if not already @@ -269,7 +282,6 @@ def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=No if auto_run: self.estimate() - def update_grid(self): """Vectorized update of the grid coordinates for multiple star stamps.""" @@ -279,9 +291,9 @@ def update_grid(self): y_range = np.arange(Ny) # Correct subtraction without mixing axes - self.xx = (x_range - self.xc[:, None]) - self.yy = (y_range - self.yc[:, None]) - + self.xx = x_range - self.xc[:, None] + self.yy = y_range - self.yc[:, None] + # Now, expand to the correct shape (num_images, Nx, Ny) # Add the extra dimension for the number of stars self.xx = self.xx[:, :, None] # Shape: (num_images, Nx, 1) @@ -295,7 +307,7 @@ def elliptical_gaussian(self, e1=0, e2=0): # Shear the grid coordinates gxx = (1 - e1) * self.xx - e2 * self.yy gyy = (1 + e1) * self.yy - e2 * self.xx - + # Compute elliptical Gaussian return np.exp(-(gxx**2 + gyy**2) / (2 * self.sigma_init**2)) @@ -309,7 +321,11 @@ def compute_moments(self): Q0 = np.sum(masked_im_window, axis=(1, 2)) # Sum over images and their pixels Q1 = np.array( [ - np.sum(np.sum(masked_im_window, axis=2 - i) * np.arange(self.stamp_size[i]), axis=1) + np.sum( + np.sum(masked_im_window, axis=2 - i) + * np.arange(self.stamp_size[i]), + axis=1, + ) for i in range(2) ] ) @@ -331,7 +347,7 @@ def get_centroids(self): def get_intra_pixel_shifts(self): """Get intra-pixel shifts for all images. - + Intra-pixel shifts are the differences between the estimated centroid and the center of the image stamp (or pixel grid). These shifts are calculated for all images in the batch. Returns @@ -339,8 +355,8 @@ def get_intra_pixel_shifts(self): np.array A 2D array of shape (num_of_images, 2), where each row corresponds to the x and y shifts for each image. 
""" - shifts = np.stack([self.xc - self.xc0, self.yc - self.yc0], axis=-1) - + shifts = np.stack([self.xc - self.xc0, self.yc - self.yc0], axis=-1) + return shifts diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 9bf9d1fd..399ee9ef 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -58,12 +58,14 @@ def build( centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) masks = get_data_array(data, run_type, key="masks", allow_missing=True) centroid_dataset = {"stamps": stamps, "masks": masks} positions = get_data_array(data=data, run_type=run_type, key="positions") - + if model_params.use_prior: if prior is not None: logger.warning( diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5c27fdbf..c7d73249 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -28,7 +28,6 @@ def __init__(self, inference_config_path: str): self.training_config = None self.data_config = None - def load_configs(self): """Load configuration files based on the inference config.""" self.inference_config = read_conf(self.inference_config_path) @@ -39,7 +38,6 @@ def load_configs(self): # Load the data configuration self.data_config = read_conf(self.data_config_path) - def set_config_paths(self): """Extract and set the configuration paths.""" # Set config paths @@ -47,10 +45,11 @@ def set_config_paths(self): self.trained_model_path = Path(config_paths.trained_model_path) self.model_subdir = config_paths.model_subdir - self.trained_model_config_path = self.trained_model_path / config_paths.trained_model_config_path + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) self.data_config_path = config_paths.data_config_path - @staticmethod def overwrite_model_params(training_config=None, inference_config=None): """ @@ -75,7 +74,6 @@ def overwrite_model_params(training_config=None, inference_config=None): if hasattr(model_params, key): setattr(model_params, key, value) - class PSFInference: """ @@ -101,7 +99,15 @@ class PSFInference: Corresponding masks for the sources (same shape as sources). Defaults to None. 
""" - def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): self.inference_config_path = inference_config_path @@ -111,7 +117,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self.seds = seds self.sources = sources self.masks = masks - + # Internal caches for lazy-loading self._config_handler = None self._simPSF = None @@ -123,7 +129,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self._output_dim = None # Initialise PSF Inference engine - self.engine = None + self.engine = None @property def config_handler(self): @@ -157,7 +163,6 @@ def simPSF(self): self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF - def _prepare_dataset_for_inference(self): """Prepare dataset dictionary for inference, returning None if positions are invalid.""" positions = self.get_positions() @@ -167,7 +172,7 @@ def _prepare_dataset_for_inference(self): @property def data_handler(self): - if self._data_handler is None: + if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( dataset_type="inference", @@ -176,7 +181,7 @@ def data_handler(self): n_bins_lambda=self.n_bins_lambda, load_data=False, dataset=self._prepare_dataset_for_inference(), - sed_data = self.seds, + sed_data=self.seds, ) self._data_handler.run_type = "inference" return self._data_handler @@ -190,13 +195,13 @@ def trained_psf_model(self): def get_positions(self): """ Combine x_field and y_field into position pairs. - + Returns ------- numpy.ndarray Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None. - + Raises ------ ValueError @@ -204,25 +209,27 @@ def get_positions(self): """ if self.x_field is None or self.y_field is None: return None - + x_arr = np.asarray(self.x_field) y_arr = np.asarray(self.y_field) - + if x_arr.size == 0 or y_arr.size == 0: return None if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") - + raise ValueError( + f"x_field and y_field must have the same length. 
" + f"Got {x_arr.size} and {y_arr.size}" + ) + # Flatten arrays to handle any input shape, then stack x_flat = x_arr.flatten() y_flat = y_arr.flatten() - + return np.column_stack((x_flat, y_flat)) def load_inference_model(self): - """Load the trained PSF model based on the inference configuration.""" + """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -231,9 +238,9 @@ def load_inference_model(self): weights_path_pattern = os.path.join( model_path, model_dir, - f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", ) - + # Load the trained PSF model return load_trained_psf_model( self.training_config, @@ -244,7 +251,9 @@ def load_inference_model(self): @property def n_bins_lambda(self): if self._n_bins_lambda is None: - self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) return self._n_bins_lambda @property @@ -279,8 +288,10 @@ def _prepare_positions_and_seds(self): y_arr = np.atleast_1d(self.y_field) if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) # Combine into positions array (n_samples, 2) positions = np.column_stack((x_arr, y_arr)) @@ -288,12 +299,16 @@ def _prepare_positions_and_seds(self): # Ensure SEDs have shape (n_samples, n_bins, 2) sed_data = ensure_batch(self.seds) - + if sed_data.shape[0] != positions.shape[0]: - raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) if sed_data.shape[2] != 2: - raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}" + ) # Process SEDs through the data handler self.data_handler.process_sed_data(sed_data) @@ -301,7 +316,6 @@ def _prepare_positions_and_seds(self): return positions, sed_data_tensor - def run_inference(self): """Run PSF inference and return the full PSF array.""" # Prepare the configuration for inference @@ -332,7 +346,7 @@ def get_psf(self, index: int = 0) -> np.ndarray: If only a single star was passed during instantiation, the index defaults to 0. 
""" self._ensure_psf_inference_completed() - + inferred_psfs = self.engine.get_psfs() # If a single-star batch, ignore index bounds @@ -358,7 +372,9 @@ def inferred_psfs(self) -> np.ndarray: def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: """Compute and cache PSFs for the input source parameters.""" n_samples = positions.shape[0] - self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32) + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) # Initialize counter counter = 0 @@ -370,14 +386,14 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: batch_pos = positions[counter:end_sample, :] batch_seds = sed_data[counter:end_sample, :, :] batch_inputs = [batch_pos, batch_seds] - + # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter counter = end_sample - + return self._inferred_psfs def get_psfs(self) -> np.ndarray: @@ -391,5 +407,3 @@ def get_psf(self, index: int) -> np.ndarray: if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs[index] - - diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index b2d06a20..6b745ad4 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -59,8 +59,8 @@ class CCDMisalignmentCalculator: This class processes and analyzes CCD misalignment data using tile position information. - The `tiles_data` array is a data cube where each slice is a 4×3 matrix representing - the four corners of a tile. The first two columns correspond to x/y coordinates (in mm), + The `tiles_data` array is a data cube where each slice is a 4×3 matrix representing + the four corners of a tile. The first two columns correspond to x/y coordinates (in mm), and the third column represents z displacement (in µm). Parameters @@ -103,6 +103,7 @@ class CCDMisalignmentCalculator: d_list : np.ndarray List of plane offset values for CCD planes. 
""" + def __init__( self, tiles_path: str, @@ -123,7 +124,11 @@ def __init__( raise ValueError("Tile data must have three coordinate columns (x, y, z).") # Initialize attributes - self.tiles_x_lims, self.tiles_y_lims, self.tiles_z_lims = np.zeros(2), np.zeros(2), np.zeros(2) + self.tiles_x_lims, self.tiles_y_lims, self.tiles_z_lims = ( + np.zeros(2), + np.zeros(2), + np.zeros(2), + ) self.tiles_z_average: float = 0.0 self.ccd_polygons: list[mpltPath.Path] = [] @@ -135,7 +140,6 @@ def __init__( self._initialize() - def _initialize(self) -> None: """Run all required initialization steps.""" self._preprocess_tile_data() @@ -145,12 +149,17 @@ def _initialize(self) -> None: def _preprocess_tile_data(self) -> None: """Preprocess tile data by computing spatial limits and averages.""" - self.tiles_x_lims = np.array([np.min(self.tiles_data[:, 0, :]), np.max(self.tiles_data[:, 0, :])]) - self.tiles_y_lims = np.array([np.min(self.tiles_data[:, 1, :]), np.max(self.tiles_data[:, 1, :])]) - self.tiles_z_lims = np.array([np.min(self.tiles_data[:, 2, :]), np.max(self.tiles_data[:, 2, :])]) + self.tiles_x_lims = np.array( + [np.min(self.tiles_data[:, 0, :]), np.max(self.tiles_data[:, 0, :])] + ) + self.tiles_y_lims = np.array( + [np.min(self.tiles_data[:, 1, :]), np.max(self.tiles_data[:, 1, :])] + ) + self.tiles_z_lims = np.array( + [np.min(self.tiles_data[:, 2, :]), np.max(self.tiles_data[:, 2, :])] + ) self.tiles_z_average = np.mean(self.tiles_z_lims) - def _initialize_polygons(self): """Initialize polygons to look for CCD IDs""" @@ -221,7 +230,6 @@ def _precompute_CCD_planes(self): self.normal_list.append(normal) self.d_list.append(d) - def scale_position_to_tile_reference(self, pos): """Scale input position into tiles coordinate system. @@ -251,7 +259,6 @@ def scale_position_to_tile_reference(self, pos): return np.array([scaled_x, scaled_y]) - def scale_position_to_wavediff_reference(self, pos): """Scale input position into wavediff coordinate system. @@ -297,7 +304,6 @@ def check_position_tile_limits(self, pos): raise ValueError( "Input position is not within the tile focal plane limits." ) - def get_ccd_from_position(self, pos): """Get CCD ID from the position. @@ -340,7 +346,6 @@ def get_ccd_from_position(self, pos): return ccd_id - def get_dz_from_position(self, pos): """Get z-axis displacement for a focal plane position. @@ -369,7 +374,6 @@ def get_dz_from_position(self, pos): return dz - def get_zk4_from_position(self, pos): """Get defocus Zernike contribution from focal plane position. @@ -390,7 +394,6 @@ def get_zk4_from_position(self, pos): return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) - @staticmethod def compute_z_from_plane_data(pos, normal, d): """Compute z value from plane data. 
@@ -420,7 +423,6 @@ def compute_z_from_plane_data(pos, normal, d): return z - @staticmethod def check_position_format(pos): if type(pos) is list: diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 5680cecc..3fa9d0bd 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -36,15 +36,12 @@ def __init__(self, metrics_params, trained_model): self.metrics_params = metrics_params self.trained_model = trained_model - def evaluate_metrics_polychromatic_lowres(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + def evaluate_metrics_polychromatic_lowres( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate RMSE metrics for low-resolution polychromatic PSF. - This method computes Root Mean Square Error (RMSE) metrics for a + This method computes Root Mean Square Error (RMSE) metrics for a low-resolution polychromatic Point Spread Function (PSF) model. Parameters @@ -62,14 +59,14 @@ def evaluate_metrics_polychromatic_lowres(self, - ``C_poly`` Tensor or None, optional The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing the RMSE, relative RMSE, and their + A dictionary containing the RMSE, relative RMSE, and their corresponding standard deviation values. - ``rmse`` : float @@ -113,17 +110,13 @@ def evaluate_metrics_polychromatic_lowres(self, "std_rel_rmse": std_rel_rmse, } - - def evaluate_metrics_mono_rmse(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + def evaluate_metrics_mono_rmse( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate RMSE metrics for Monochromatic PSF. - This method computes Root Mean Square Error (RMSE) metrics for a - monochromatic Point Spread Function (PSF) model across a range of + This method computes Root Mean Square Error (RMSE) metrics for a + monochromatic Point Spread Function (PSF) model across a range of wavelengths. Parameters @@ -140,13 +133,13 @@ def evaluate_metrics_mono_rmse(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing RMSE, relative RMSE, and their corresponding + A dictionary containing RMSE, relative RMSE, and their corresponding standard deviation values computed over a wavelength range. 
- ``rmse_lda`` : float @@ -187,17 +180,13 @@ def evaluate_metrics_mono_rmse(self, "std_rmse_lda": std_rmse_lda, "std_rel_rmse_lda": std_rel_rmse_lda, } - - - def evaluate_metrics_opd(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + + def evaluate_metrics_opd( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate Optical Path Difference (OPD) metrics. - - This method computes Root Mean Square Error (RMSE) and relative RMSE + + This method computes Root Mean Square Error (RMSE) and relative RMSE for Optical Path Differences (OPD), along with their standard deviations. Parameters @@ -214,13 +203,13 @@ def evaluate_metrics_opd(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns ------- dict - A dictionary containing RMSE, relative RMSE, and their corresponding + A dictionary containing RMSE, relative RMSE, and their corresponding standard deviation values for OPD metrics. - ``rmse_opd`` : float @@ -259,16 +248,12 @@ def evaluate_metrics_opd(self, "rel_rmse_std_opd": rel_rmse_std_opd, } - - def evaluate_metrics_shape(self, - psf_model: Any, - simPSF: Any, - data: Any, - dataset: Dict[str, Any] - ) -> Dict[str, float]: + def evaluate_metrics_shape( + self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any] + ) -> Dict[str, float]: """Evaluate PSF Shape Metrics. - Computes shape-related metrics for the PSF model, including RMSE, + Computes shape-related metrics for the PSF model, including RMSE, relative RMSE, and their standard deviations. Parameters @@ -286,7 +271,7 @@ def evaluate_metrics_shape(self, - ``C_poly`` (Tensor or None, optional) The Zernike coefficient matrix used in generating simulations of the PSF model. This matrix defines the Zernike polynomials up to a given order used to simulate the PSF - field. It may be present in some datasets or only required for some classes. + field. It may be present in some datasets or only required for some classes. If not present or required, the model will proceed without it. Returns @@ -328,7 +313,7 @@ def evaluate_model( metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. - + The metrics to evaluate are determined by the configuration in `metrics_params` and `metric_evaluation_flags`. Metrics are computed for both the training and test datasets, and results are stored in a dictionary. 
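[Editor's note] As a reading aid for the metric dictionaries documented in the hunks above, here is a minimal sketch of how RMSE / relative-RMSE statistics of this shape can be computed from predicted and ground-truth PSF stamps. The helper name and the percent normalisation are illustrative assumptions; the actual reductions live in wf_psf.metrics and may differ:

    import numpy as np

    def rmse_stats(psfs_pred: np.ndarray, psfs_gt: np.ndarray) -> dict:
        # Inputs are stacks of stamps with shape (n_stars, H, W).
        residual = psfs_pred - psfs_gt
        # Per-star RMSE over pixels.
        rmse = np.sqrt(np.mean(residual**2, axis=(1, 2)))
        # Relative RMSE: normalise by the ground-truth pixel RMS (here in percent).
        rel_rmse = 100.0 * rmse / np.sqrt(np.mean(psfs_gt**2, axis=(1, 2)))
        return {
            "rmse": float(np.mean(rmse)),
            "rel_rmse": float(np.mean(rel_rmse)),
            "std_rmse": float(np.std(rmse)),
            "std_rel_rmse": float(np.std(rel_rmse)),
        }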
diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index c8086cb7..e2c84868 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -12,9 +12,9 @@ from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_data_array from wf_psf.data.data_zernike_utils import ( - ZernikeInputsFactory, + ZernikeInputsFactory, assemble_zernike_contributions, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -124,7 +124,7 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.output_Q = model_params.output_Q self.l2_param = model_params.param_hparams.l2_param self.output_dim = model_params.output_dim - + # Initialise lazy loading of external Zernike prior self._external_prior = None @@ -134,25 +134,26 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): # Compute contributions once eagerly (outside graph) zks_total_contribution_np = self._assemble_zernike_contributions().numpy() - self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) - + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 + ) + # Compute n_zks_total as int self._n_zks_total = max( self.model_params.param_hparams.n_zernikes, - zks_total_contribution_np.shape[1] + zks_total_contribution_np.shape[1], ) - - # Precompute zernike maps as tf.float32 + + # Precompute zernike maps as tf.float32 self._zernike_maps = psfm.generate_zernike_maps_3d( - n_zernikes=self._n_zks_total, - pupil_diam=self.model_params.pupil_diameter - ) + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter + ) - # Precompute OPD dimension + # Precompute OPD dimension self._opd_dim = self._zernike_maps.shape[1] # Precompute obscurations as tf.complex64 - self._obscurations = psfm.tf_obscurations( + self._obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, @@ -164,10 +165,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): _ = self.tf_np_poly_opd def _get_run_type(self, data): - if hasattr(data, 'run_type'): + if hasattr(data, "run_type"): run_type = data.run_type - elif isinstance(data, dict) and 'run_type' in data: - run_type = data['run_type'] + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] else: raise ValueError("data must have a 'run_type' attribute or key") @@ -193,17 +194,28 @@ def _assemble_zernike_contributions(self): @property def save_param_history(self) -> bool: """Check if the model should save the optimization history for parametric features.""" - return getattr(self.model_params.param_hparams, "save_optim_history_param", False) - + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) + @property def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" - return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) def get_obs_pos(self): - assert self.run_type in 
{"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - - raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions") + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" + + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" + ) obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) @@ -213,7 +225,7 @@ def get_obs_pos(self): @property def zks_total_contribution(self): return self._zks_total_contribution - + @property def n_zks_total(self): """Get the total number of Zernike coefficients.""" @@ -254,40 +266,40 @@ def tf_physical_layer(self): """Lazy loading of the physical layer of the PSF model.""" if not hasattr(self, "_tf_physical_layer"): self._tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_total_contribution, - interpolation_type=self.model_params.interpolation_type, - interpolation_args=self.model_params.interpolation_args, - ) + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) return self._tf_physical_layer - + @property def tf_zernike_OPD(self): """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" if not hasattr(self, "_tf_zernike_OPD"): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - + def _build_tf_batch_poly_PSF(self): """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" return TFBatchPolychromaticPSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) + obscurations=self.obscurations, + output_Q=self.output_Q, + output_dim=self.output_dim, + ) @property def tf_np_poly_opd(self): """Lazy loading of the non-parametric polynomial variations OPD layer.""" if not hasattr(self, "_tf_np_poly_opd"): self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=self.model_params.x_lims, - y_lims=self.model_params.y_lims, - random_seed=self.model_params.param_hparams.random_seed, - d_max=self.model_params.nonparam_hparams.d_max_nonparam, - opd_dim=self.opd_dim, - ) + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) return self._tf_np_poly_opd def get_coeff_matrix(self): @@ -313,7 +325,6 @@ def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None: """ self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None: """Set the output sampling rate (output_Q) for PSF generation. 
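[Editor's note] The prediction hunks that follow all repeat one assembly pattern. As a reading aid, here is that pipeline in isolation. The attribute and method names are taken from this class; `model`, `input_positions`, and `packed_SEDs` are free variables, and the shape comments are assumptions:

    # Parametric Zernikes -> OPD, plus non-parametric OPD, then polychromatic PSFs.
    zks_coeffs = model.predict_zernikes(input_positions)       # (batch, n_zks, 1, 1)
    param_opd_maps = model.tf_zernike_OPD(zks_coeffs)          # (batch, opd_dim, opd_dim)
    nonparam_opd_maps = model.tf_np_poly_opd(input_positions)  # same shape as above
    opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
    poly_psfs = model.tf_batch_poly_PSF([opd_maps, packed_SEDs])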
@@ -453,16 +464,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -505,13 +516,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -536,13 +547,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -677,27 +688,25 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) # Add L2 regularization loss on parametric OPD maps - self.add_loss( - self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) - ) + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) # Combine both contributions opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index c30d31ad..e445f7af 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,15 +7,14 @@ Author: Jennifer Pollack """ + import logging -from wf_psf.psf_models.psf_models import ( - get_psf_model, - get_psf_model_weights_filepath -) +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath import tensorflow as tf logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): """ Loads a trained PSF model and applies saved weights. @@ -40,9 +39,11 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): RuntimeError If loading the model weights fails for any reason. 
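[Editor's note] A hedged usage sketch for load_trained_psf_model; the glob pattern below is a placeholder, and get_psf_model_weights_filepath resolves it to a concrete checkpoint basename:

    # Hypothetical config objects produced by the usual wf-psf YAML readers.
    model = load_trained_psf_model(
        training_conf=training_conf,
        data_conf=data_conf,
        weights_path_pattern="wf-outputs/checkpoint/*psf_model*cycle2*",
    )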
""" - model = get_psf_model(training_conf.training.model_params, - training_conf.training.training_hparams, - data_conf) + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) weights_path = get_psf_model_weights_filepath(weights_path_pattern) @@ -55,4 +56,3 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): logger.exception("Failed to load model weights.") raise RuntimeError("Model weight loading failed.") from e return model - diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index fdea0077..98d450f2 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) + class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. @@ -925,7 +926,6 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] - def call(self, positions): """Calculate the prior Zernike coefficients for a batch of positions. diff --git a/src/wf_psf/psf_models/tf_modules/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py index 5a597847..2d93834e 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_modules.py +++ b/src/wf_psf/psf_models/tf_modules/tf_modules.py @@ -1,10 +1,11 @@ """TensorFlow-Based PSF Modeling. -A module containing TensorFlow implementations for modeling monochromatic PSFs using Zernike polynomials and Fourier optics. +A module containing TensorFlow implementations for modeling monochromatic PSFs using Zernike polynomials and Fourier optics. :Author: Tobias Liaudat """ + import numpy as np import tensorflow as tf from typing import Optional @@ -21,7 +22,9 @@ class TFFftDiffract(tf.Module): Downsampling factor. Must be integer. """ - def __init__(self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] = None) -> None: + def __init__( + self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] = None + ) -> None: """Initialize the TFFftDiffract class. Parameters @@ -47,15 +50,15 @@ def __init__(self, output_dim: int = 64, output_Q: int = 2, name: Optional[str] def tf_crop_img(self, image, output_crop_dim): """Crop images using TensorFlow methods. - This method handles a batch of 2D images and crops them to the specified dimension. - The images are expected to have the shape [batch, width, height], and the method + This method handles a batch of 2D images and crops them to the specified dimension. + The images are expected to have the shape [batch, width, height], and the method uses TensorFlow's `crop_to_bounding_box` to crop each image in the batch. Parameters ---------- image : tf.Tensor A batch of 2D images with shape [batch, height, width]. The images are expected - to be 3D tensors where the second and third dimensions represent the height and width. + to be 3D tensors where the second and third dimensions represent the height and width. output_crop_dim : int The dimension of the square crop. The image will be cropped to this dimension. @@ -108,8 +111,8 @@ def normalize_psf(self, psf): def __call__(self, input_phase): """Calculate the normalized Point Spread Function (PSF) from a phase array. - This method takes a 2D input phase array, applies a 2D FFT-based diffraction operation, - crops the resulting PSF, and downscales it by a factor of Q if necessary. 
Finally, the PSF + This method takes a 2D input phase array, applies a 2D FFT-based diffraction operation, + crops the resulting PSF, and downscales it by a factor of Q if necessary. Finally, the PSF is normalized by summing over its spatial dimensions. Parameters @@ -120,7 +123,7 @@ def __call__(self, input_phase): Returns ------- tf.Tensor - The normalized PSF tensor with shape [batch, height, width], where each PSF is normalized + The normalized PSF tensor with shape [batch, height, width], where each PSF is normalized by its sum over the spatial dimensions. """ # Perform the FFT-based diffraction operation @@ -174,7 +177,13 @@ class TFBuildPhase(tf.Module): A tensor representing the obscurations (e.g., apertures or masks) to be applied to the phase. """ - def __init__(self, phase_N: int, lambda_obs: float, obscurations: tf.Tensor, name: Optional[str] = None) -> None: + def __init__( + self, + phase_N: int, + lambda_obs: float, + obscurations: tf.Tensor, + name: Optional[str] = None, + ) -> None: """Initialize the TFBuildPhase class. Parameters @@ -225,7 +234,6 @@ def zero_padding_diffraction(self, no_pad_phase): padded_phase = tf.pad(no_pad_phase, padding) return padded_phase - def apply_obscurations(self, phase: tf.Tensor) -> tf.Tensor: """Apply obscurations to the phase map. @@ -295,15 +303,15 @@ def __call__(self, opd): class TFZernikeOPD(tf.Module): """Convert Zernike coefficients into an Optical Path Difference (OPD). - This class performs the weighted sum of Zernike coefficients and Zernike maps - to compute the OPD. The Zernike maps and the corresponding Zernike coefficients + This class performs the weighted sum of Zernike coefficients and Zernike maps + to compute the OPD. The Zernike maps and the corresponding Zernike coefficients are required to perform the calculation. Parameters ---------- zernike_maps : tf.Tensor - A tensor containing the Zernike maps. The shape should be - (num_coeffs, x_dim, y_dim), where `num_coeffs` is the number of Zernike coefficients + A tensor containing the Zernike maps. The shape should be + (num_coeffs, x_dim, y_dim), where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are the dimensions of each map. name : str, optional @@ -312,12 +320,12 @@ class TFZernikeOPD(tf.Module): Returns ------- tf.Tensor - A tensor representing the OPD, with shape (num_star, x_dim, y_dim), - where `num_star` corresponds to the number of stars and `x_dim`, `y_dim` are + A tensor representing the OPD, with shape (num_star, x_dim, y_dim), + where `num_star` corresponds to the number of stars and `x_dim`, `y_dim` are the dimensions of the OPD map. """ - def __init__(self, zernike_maps : tf.Tensor, name: Optional[str] = None) -> None: + def __init__(self, zernike_maps: tf.Tensor, name: Optional[str] = None) -> None: """ Initialize the TFZernikeOPD class. @@ -332,18 +340,18 @@ def __init__(self, zernike_maps : tf.Tensor, name: Optional[str] = None) -> None self.zernike_maps = zernike_maps - def __call__(self, z_coeffs : tf.Tensor) -> tf.Tensor: + def __call__(self, z_coeffs: tf.Tensor) -> tf.Tensor: """Compute the OPD from Zernike coefficients and maps. - This method calculates the OPD by performing the weighted sum of Zernike - coefficients and corresponding Zernike maps. The result is a tensor representing + This method calculates the OPD by performing the weighted sum of Zernike + coefficients and corresponding Zernike maps. The result is a tensor representing the computed OPD for the given coefficients. 
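[Editor's note] In code form, the weighted sum described above is a broadcasted multiply-and-reduce. This is a sketch consistent with the shapes given in this docstring; the actual implementation may use an equivalent formulation:

    # z_coeffs: (num_star, num_coeffs, 1, 1); zernike_maps: (num_coeffs, x_dim, y_dim)
    opd = tf.math.reduce_sum(tf.math.multiply(z_coeffs, zernike_maps), axis=1)
    # opd: (num_star, x_dim, y_dim)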
Parameters ---------- z_coeffs : tf.Tensor - A tensor containing the Zernike coefficients. The shape should be - (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and + A tensor containing the Zernike coefficients. The shape should be + (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and `num_coeffs` is the number of Zernike coefficients. Returns @@ -358,30 +366,30 @@ def __call__(self, z_coeffs : tf.Tensor) -> tf.Tensor: class TFZernikeMonochromaticPSF(tf.Module): """Build a monochromatic Point Spread Function (PSF) from Zernike coefficients. - This class computes the monochromatic PSF by following the Zernike model. It - involves multiple stages, including the calculation of the OPD (Optical Path - Difference), the phase from the OPD, and diffraction via FFT-based operations. + This class computes the monochromatic PSF by following the Zernike model. It + involves multiple stages, including the calculation of the OPD (Optical Path + Difference), the phase from the OPD, and diffraction via FFT-based operations. The Zernike coefficients are used to generate the PSF. Parameters ---------- phase_N : int The size of the phase grid, typically a square matrix dimension. - + lambda_obs : float The wavelength of the observed light. - + obscurations : tf.Tensor - A tensor representing the obscurations in the system, which will be applied + A tensor representing the obscurations in the system, which will be applied to the phase. zernike_maps : tf.Tensor - A tensor containing the Zernike maps, with the shape (num_coeffs, x_dim, y_dim), - where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are + A tensor containing the Zernike maps, with the shape (num_coeffs, x_dim, y_dim), + where `num_coeffs` is the number of Zernike coefficients and `x_dim`, `y_dim` are the dimensions of the Zernike maps. output_dim : int, optional, default=64 - The output dimension of the PSF, i.e., the size of the resulting image. + The output dimension of the PSF, i.e., the size of the resulting image. name : str, optional The name of the module. Default is `None`. @@ -390,17 +398,22 @@ class TFZernikeMonochromaticPSF(tf.Module): ---------- tf_build_opd_zernike : TFZernikeOPD A module used to generate the OPD from the Zernike coefficients. - + tf_build_phase : TFBuildPhase A module used to compute the phase from the OPD. - + tf_fft_diffract : TFFftDiffract A module that performs the diffraction calculation using FFT-based methods. """ def __init__( - self, phase_N: int, lambda_obs: float, obscurations: tf.Tensor, - zernike_maps: tf.Tensor, output_dim: int = 64, name: Optional[str] = None + self, + phase_N: int, + lambda_obs: float, + obscurations: tf.Tensor, + zernike_maps: tf.Tensor, + output_dim: int = 64, + name: Optional[str] = None, ): """ Initialize the TFZernikeMonochromaticPSF class. @@ -437,15 +450,15 @@ def __call__(self, z_coeffs): Parameters ---------- z_coeffs : tf.Tensor - A tensor containing the Zernike coefficients. The shape should be - (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars + A tensor containing the Zernike coefficients. The shape should be + (num_star, num_coeffs, 1, 1), where `num_star` is the number of stars and `num_coeffs` is the number of Zernike coefficients. 
Returns ------- tf.Tensor - A tensor representing the computed PSF, with shape - (num_star, output_dim, output_dim), where `output_dim` is the size of + A tensor representing the computed PSF, with shape + (num_star, output_dim, output_dim), where `output_dim` is the size of the resulting PSF image. """ # Generate OPD from Zernike coefficients diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index d0a2002c..09540e60 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -1,15 +1,15 @@ """TensorFlow Utilities Module. -Provides lightweight utility functions for safely converting and managing data types +Provides lightweight utility functions for safely converting and managing data types within TensorFlow-based workflows. Includes: - `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype -These tools are designed to support PSF model components, including lazy property evaluation, +These tools are designed to support PSF model components, including lazy property evaluation, data input validation, and type normalization. -This module is intended for internal use in model layers and inference components to enforce +This module is intended for internal use in model layers and inference components to enforce TensorFlow-compatible inputs. Authors: Jennifer Pollack @@ -22,68 +22,69 @@ @tf.function def find_position_indices(obs_pos, batch_positions): """Find indices of batch positions within observed positions using vectorized operations. - - This function locates the indices of multiple query positions within a + + This function locates the indices of multiple query positions within a reference set of observed positions using broadcasting and vectorized operations. Each position in the batch must have an exact match in the observed positions. Parameters ---------- obs_pos : tf.Tensor - Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of observed positions. Each row contains [x, y] coordinates. batch_positions : tf.Tensor - Query positions tensor of shape (batch_size, 2), where batch_size is the number + Query positions tensor of shape (batch_size, 2), where batch_size is the number of positions to look up. Each row contains [x, y] coordinates. - + Returns ------- indices : tf.Tensor - Tensor of shape (batch_size,) containing the indices of each batch position + Tensor of shape (batch_size,) containing the indices of each batch position within obs_pos. The dtype is tf.int64. - + Raises ------ tf.errors.InvalidArgumentError If any position in batch_positions is not found in obs_pos. - + Notes ----- - Uses exact equality matching - positions must match exactly. More efficient than + Uses exact equality matching - positions must match exactly. More efficient than iterative lookups for multiple positions due to vectorized operations. 
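[Editor's note] An illustrative call of the lookup documented above (the values are made up; the implementation follows below):

    obs_pos = tf.constant([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
    batch_positions = tf.constant([[3.0, 4.0], [0.0, 0.0]])
    indices = find_position_indices(obs_pos, batch_positions)
    # indices -> [2, 0], dtype tf.int64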
""" # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) obs_expanded = tf.expand_dims(obs_pos, 0) pos_expanded = tf.expand_dims(batch_positions, 1) - + # Compare all positions at once: (batch_size, n_obs) matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) - + # Find the index of the matching position for each batch item - # argmax returns the first True value's index along axis=1 + # argmax returns the first True value's index along axis=1 indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) - + # Verify all positions were found tf.debugging.assert_equal( tf.reduce_all(tf.reduce_any(matches, axis=1)), True, - message="Some positions not found in obs_pos" + message="Some positions not found in obs_pos", ) - + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. - + Parameters ---------- input_array : array-like, tf.Tensor, or np.ndarray The input to convert. dtype : tf.DType, optional The desired TensorFlow dtype (default: tf.float32). - + Returns ------- tf.Tensor @@ -97,4 +98,3 @@ def ensure_tensor(input_array, dtype=tf.float32): else: # Convert numpy arrays or other types to tensor return tf.convert_to_tensor(input_array, dtype=dtype) - diff --git a/src/wf_psf/sims/psf_simulator.py b/src/wf_psf/sims/psf_simulator.py index f953cded..9b7b58af 100644 --- a/src/wf_psf/sims/psf_simulator.py +++ b/src/wf_psf/sims/psf_simulator.py @@ -227,7 +227,7 @@ def generate_euclid_pupil_obscurations(N_pix=1024, N_filter=3, rotation_angle=0) """Generate Euclid like pupil obscurations. This method simulates the 2D pupil obscurations for the Euclid telescope, - considering the aperture stop, mirror obscurations, and spider arms. It does + considering the aperture stop, mirror obscurations, and spider arms. It does not account for any 3D projections or the angle of the Field of View (FoV). 
Parameters diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", "\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - "print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = 
control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = 
unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control 
Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 9a6c6acc..1cb0db16 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,25 +8,23 @@ import numpy as np import pytest -from wf_psf.data.centroids import ( - compute_centroid_correction, - CentroidEstimator -) +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator from wf_psf.data.data_handler import extract_star_data from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch + # Function to compute centroid based on first-order moments def calculate_centroid(image, mask=None): if mask is not None: image = np.ma.masked_array(image, mask=mask) - + # Calculate moments M00 = np.sum(image) M10 = np.sum(np.arange(image.shape[1]) * np.sum(image, axis=0)) M01 = np.sum(np.arange(image.shape[0]) * np.sum(image, axis=1)) - + # Centroid formula xc = M10 / M00 yc = M01 / M00 @@ -37,60 +35,69 @@ def calculate_centroid(image, mask=None): def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array (5x5 image for each "image") - + image = np.zeros( + (num_images, 5, 5) + ) # Create a 3D array (5x5 image for each "image") + # Place non-zero values in multiple pixels image[:, 2, 2] = 10 # Star at the center - image[:, 2, 3] = 10 # Adjacent pixel + image[:, 2, 3] = 10 # Adjacent pixel image[:, 3, 2] = 10 # Adjacent pixel image[:, 3, 3] = 10 # Adjacent pixel forming a symmetric pattern mask = np.zeros_like(image) - mask[:, 3, 2] = 1 - mask[:, 3, 3] = 1 + mask[:, 3, 2] = 1 + mask[:, 3, 3] = 1 return image, mask + @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" - num_images = simple_image.shape[0] # Get the number of images from the first dimension + num_images = simple_image.shape[ + 0 + ] # Get the number of images from the first dimension mask = np.ones((num_images, 5, 5)) # Create a batch of masks mask[:, 1:4, 1:4] = 0 # Mask a 3x3 region for each image return simple_image, mask + @pytest.fixture def centroid_estimator(simple_image): """Fixture for initializing CentroidEstimator.""" return CentroidEstimator(simple_image) + @pytest.fixture def centroid_estimator_with_mask(simple_image_with_mask): """Fixture for initializing CentroidEstimator with a mask.""" image, mask = simple_image_with_mask return CentroidEstimator(image, mask=mask) + @pytest.fixture def simple_image_with_centroid(simple_image): """Fixture for a simple image with known centroid and initial position.""" image = simple_image - + # Known centroid and initial position (xc0, yc0) - for testing xc0, yc0 = 2.0, 2.0 # Assume the initial center of the image is (2.0, 2.0) - + # Create CentroidEstimator instance - centroid_estimator = CentroidEstimator(im=image, n_iter=1) + centroid_estimator = CentroidEstimator(im=image, n_iter=1) - centroid_estimator.window=np.ones_like(image) + centroid_estimator.window = np.ones_like(image) 
centroid_estimator.xc0 = xc0 centroid_estimator.yc0 = yc0 - + # Simulate the computed centroid being slightly off-center centroid_estimator.xc = 2.3 centroid_estimator.yc = 2.7 - + return centroid_estimator + @pytest.fixture def batch_images(): """Fixture for multiple PSF images.""" @@ -106,27 +113,37 @@ def test_compute_centroid_correction_with_masks(mock_data): model_params = RecursiveNamespace( pix_sampling=12e-6, # Example pixel sampling in meters correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] + reference_shifts=["-1/3", "-1/3"], ) # Mock the internal function calls: - with patch('wf_psf.data.centroids.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.centroids.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: - + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + # Mock the return values of extract_star_data and compute_zernike_tip_tilt mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) ) mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call the function under test result = compute_centroid_correction(model_params, mock_data) - + # Ensure the result has the correct shape assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose(result[0, :], np.array([0, -0.1, -0.2])) # First star Zernike coefficients - assert np.allclose(result[1, :], np.array([0, -0.3, -0.4])) # Second star Zernike coefficients + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients def test_compute_centroid_correction_without_masks(mock_data): @@ -134,26 +151,32 @@ def test_compute_centroid_correction_without_masks(mock_data): # Remove masks from mock_data mock_data.test_data.dataset["masks"] = None mock_data.training_data.dataset["masks"] = None - + # Define model parameters model_params = RecursiveNamespace( pix_sampling=12e-6, # Example pixel sampling in meters correct_centroids=True, - reference_shifts=["-1/3", "-1/3"] + reference_shifts=["-1/3", "-1/3"], ) - + # Mock internal function calls - with patch('wf_psf.data.centroids.extract_star_data') as mock_extract_star_data, \ - patch('wf_psf.data.centroids.compute_zernike_tip_tilt') as mock_compute_zernike_tip_tilt: - + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): + # Mock extract_star_data to return synthetic star postage stamps mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) if train_key == 'noisy_stars' else np.array([[5, 6], [7, 8]]) + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) ) - + # Mock compute_zernike_tip_tilt assuming no masks mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - + # Call function under test result = compute_centroid_correction(model_params, mock_data) @@ -161,12 +184,14 @@ def test_compute_centroid_correction_without_masks(mock_data): assert result.shape == (4, 
3) # (n_stars, 3 Zernike components) # Validate expected values (adjust based on behavior) - expected_result = -1.0 * np.array([ - [0, 0.1, 0.2], # From training data - [0, 0.3, 0.4], - [0, 0.1, 0.2], # From test data (reused mocked return) - [0, 0.3, 0.4] - ]) + expected_result = -1.0 * np.array( + [ + [0, 0.1, 0.2], # From training data + [0, 0.3, 0.4], + [0, 0.1, 0.2], # From test data (reused mocked return) + [0, 0.3, 0.4], + ] + ) assert np.allclose(result, expected_result) @@ -178,6 +203,7 @@ def test_centroid_calculation_one_star(centroid_estimator): assert np.isclose(xc, 2.0) assert np.isclose(yc, 2.0) + # Test for centroid calculation with mask def test_centroid_calculation_with_one_star_and_mask(centroid_estimator_with_mask): """Test that the centroid is calculated correctly when a mask is applied.""" @@ -186,53 +212,61 @@ def test_centroid_calculation_with_one_star_and_mask(centroid_estimator_with_mas assert np.isclose(xc, 2.0) assert np.isclose(yc, 2.0) + def test_centroid_calculation_multiple_images(multiple_images): """Test the centroid estimation for a batch of images.""" estimator = CentroidEstimator(im=multiple_images) - + # Check that centroids are correctly estimated expected_centroids = [(2.0, 2.0), (1.0, 3.0), (3.0, 1.0)] - for i, (xc, yc) in enumerate(zip(estimator.xc,estimator.yc)): + for i, (xc, yc) in enumerate(zip(estimator.xc, estimator.yc)): assert np.isclose(xc, expected_centroids[i][0]) assert np.isclose(yc, expected_centroids[i][1]) + def test_centroid_no_mask(simple_star_and_mask): # Extract star star, _ = simple_star_and_mask # Expected centroid for the symmetric pattern - true_centroid = (2.5, 2.5) - + true_centroid = (2.5, 2.5) + # Create the CentroidEstimator instance (assuming auto_run=True by default) centroid_estimator = CentroidEstimator(im=star, n_iter=1) - + # Check if the centroid is calculated correctly computed_centroid = (centroid_estimator.xc, centroid_estimator.yc) assert np.isclose(computed_centroid[0], true_centroid[0]) assert np.isclose(computed_centroid[1], true_centroid[1]) + # Test for centroid calculation with a mask def test_centroid_with_mask(simple_star_and_mask): # Extract star and mask star, mask = simple_star_and_mask # Expected centroid after masking (estimated manually) - expected_masked_centroid = (2.0, 2.5) - + expected_masked_centroid = (2.0, 2.5) + # Create the CentroidEstimator instance (with mask) centroid_estimator = CentroidEstimator(im=star, mask=mask, n_iter=1) - + # Check if the centroid is calculated correctly with the mask applied computed_centroid = (centroid_estimator.xc, centroid_estimator.yc) assert np.isclose(computed_centroid[0], expected_masked_centroid[0]) assert np.isclose(computed_centroid[1], expected_masked_centroid[1]) + def test_centroid_estimator_initialization(simple_image): """Test the initialization of the CentroidEstimator.""" estimator = CentroidEstimator(simple_image) assert estimator.im.shape == (1, 5, 5) # Shape should match the input image - assert estimator.xc0 == 2.5 # Default xc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 - assert estimator.yc0 == 2.5 # Default yc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 + assert ( + estimator.xc0 == 2.5 + ) # Default xc should be the center of the image, i.e. float(self.stamp_size[0]) / 2 + assert ( + estimator.yc0 == 2.5 + ) # Default yc should be the center of the image, i.e. 
float(self.stamp_size[0]) / 2 assert estimator.sigma_init == 7.5 # Default sigma_init should be 7.5 assert estimator.n_iter == 5 # Default number of iterations should be 5 assert estimator.mask is None # By default, mask should be None @@ -240,7 +274,7 @@ def test_centroid_estimator_initialization(simple_image): def test_single_iteration(centroid_estimator): """Test that the internal methods are called exactly once for n_iter=1.""" - + # Mock the methods centroid_estimator.update_grid = MagicMock() centroid_estimator.elliptical_gaussian = MagicMock() @@ -248,7 +282,7 @@ def test_single_iteration(centroid_estimator): # Set n_iter to 1 centroid_estimator.n_iter = 1 - + # Run the estimate method centroid_estimator.estimate() @@ -257,14 +291,19 @@ def test_single_iteration(centroid_estimator): centroid_estimator.elliptical_gaussian.assert_called_once() centroid_estimator.compute_moments.assert_called_once() + def test_single_iteration_auto_run(simple_image): """Test that the internal methods are called exactly once for n_iter=1.""" # Patch the methods at the time the object is created - with patch.object(CentroidEstimator, 'update_grid') as update_grid_mock, \ - patch.object(CentroidEstimator, 'elliptical_gaussian') as elliptical_gaussian_mock, \ - patch.object(CentroidEstimator, 'compute_moments') as compute_moments_mock: - + with ( + patch.object(CentroidEstimator, "update_grid") as update_grid_mock, + patch.object( + CentroidEstimator, "elliptical_gaussian" + ) as elliptical_gaussian_mock, + patch.object(CentroidEstimator, "compute_moments") as compute_moments_mock, + ): + # Initialize the CentroidEstimator with auto_run=True centroid_estimator = CentroidEstimator(im=simple_image, n_iter=1, auto_run=True) @@ -273,54 +312,80 @@ def test_single_iteration_auto_run(simple_image): elliptical_gaussian_mock.assert_called_once() compute_moments_mock.assert_called_once() + def test_update_grid(simple_image): """Test that the grid is correctly updated.""" centroid_estimator = CentroidEstimator(im=simple_image, auto_run=True, n_iter=1) - + # Check the shapes of the grid coordinates assert centroid_estimator.xx.shape == (1, 5, 5) assert centroid_estimator.yy.shape == (1, 5, 5) - + # Check the values of the grid coordinates # xx should be the same for all rows and columns (broadcasted across the image) - assert np.allclose(centroid_estimator.xx, - np.array([[[[-2.5, -2.5, -2.5, -2.5, -2.5], - [-1.5, -1.5, -1.5, -1.5, -1.5], - [-0.5, -0.5, -0.5, -0.5, -0.5], - [ 0.5, 0.5, 0.5, 0.5, 0.5], - [ 1.5, 1.5, 1.5, 1.5, 1.5]]]])) - + assert np.allclose( + centroid_estimator.xx, + np.array( + [ + [ + [ + [-2.5, -2.5, -2.5, -2.5, -2.5], + [-1.5, -1.5, -1.5, -1.5, -1.5], + [-0.5, -0.5, -0.5, -0.5, -0.5], + [0.5, 0.5, 0.5, 0.5, 0.5], + [1.5, 1.5, 1.5, 1.5, 1.5], + ] + ] + ] + ), + ) + # yy should be the same for all columns and rows (broadcasted across the image) - assert np.allclose(centroid_estimator.yy, - np.array([[[[-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5], - [-2.5, -1.5, -0.5, 0.5, 1.5]]]])) + assert np.allclose( + centroid_estimator.yy, + np.array( + [ + [ + [ + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + [-2.5, -1.5, -0.5, 0.5, 1.5], + ] + ] + ] + ), + ) + def test_elliptical_gaussian(simple_image): """Test that the elliptical Gaussian is calculated correctly.""" centroid_estimator = CentroidEstimator(im=simple_image, n_iter=1) # Check if the 
output is a valid 2D array with the correct shape assert centroid_estimator.window.shape == (1, 5, 5) - + # Check if the Gaussian window values are reasonable (non-negative and decrease with distance) assert np.all(centroid_estimator.window >= 0) - assert np.isclose(np.sum(centroid_estimator.window), 25, atol=1.0) + assert np.isclose(np.sum(centroid_estimator.window), 25, atol=1.0) def test_intra_pixel_shifts(simple_image_with_centroid): """Test the return_intra_pixel_shifts method.""" - + centroid_estimator = simple_image_with_centroid - + # Calculate intra-pixel shifts shifts = centroid_estimator.get_intra_pixel_shifts() - + # Expected intra-pixel shifts expected_x_shift = 2.3 - 2.0 # xc - xc0 expected_y_shift = 2.7 - 2.0 # yc - yc0 - + # Check that the shifts are correct - assert np.isclose(shifts[0], expected_x_shift), f"Expected {expected_x_shift}, got {shifts[0]}" - assert np.isclose(shifts[1], expected_y_shift), f"Expected {expected_y_shift}, got {shifts[1]}" + assert np.isclose( + shifts[0], expected_x_shift + ), f"Expected {expected_x_shift}, got {shifts[0]}" + assert np.isclose( + shifts[1], expected_y_shift + ), f"Expected {expected_y_shift}, got {shifts[1]}" diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 5a76c5af..c1310d21 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -10,10 +10,12 @@ import logging from unittest.mock import patch + def mock_sed(): # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like return np.linspace(0.1, 1.0, 50) + def test_process_sed_data_auto_load(data_params, simPSF): # load_data=True → dataset is used and SEDs processed automatically data_handler = DataHandler( @@ -43,7 +45,9 @@ def test_load_train_dataset(tmp_path, simPSF): data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) # Call the load_dataset method data_handler.load_dataset() @@ -74,14 +78,15 @@ def test_load_test_dataset(tmp_path, simPSF): # Initialize DataHandler instance data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - + n_bins_lambda = 10 data_handler = DataHandler( - dataset_type="test", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=n_bins_lambda, - load_data=False) + dataset_type="test", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, + ) # Call the load_dataset method data_handler.load_dataset() @@ -102,18 +107,23 @@ def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): "positions": np.array([[1, 2], [3, 4]]), # No 'noisy_stars' key "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), } - + np.save(temp_data_file, mock_dataset) data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False) + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda, load_data=False + ) - with pytest.raises(ValueError, match="Missing required field 'noisy_stars' in training dataset."): + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." 
+ ): data_handler.load_dataset() data_handler.validate_and_process_dataset() + def test_load_test_dataset_missing_stars(tmp_path, simPSF): """Test that a warning is raised if 'stars' is missing in test data.""" data_dir = tmp_path / "data" @@ -130,9 +140,13 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 - data_handler = DataHandler("test", data_params, simPSF, n_bins_lambda, load_data=False) + data_handler = DataHandler( + "test", data_params, simPSF, n_bins_lambda, load_data=False + ) - with pytest.raises(ValueError, match="Missing required field 'stars' in test dataset."): + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." + ): data_handler.load_dataset() data_handler.validate_and_process_dataset() @@ -146,17 +160,17 @@ def test_get_np_obs_positions(mock_data): def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - - expected = tf.concat([ - tf.constant([ - np.arange(25).reshape(5, 5), - np.arange(25, 50).reshape(5, 5) - ], dtype=tf.float32), - tf.constant([ - np.full((5, 5), 100), - np.full((5, 5), 200) - ], dtype=tf.float32), - ], axis=0) + + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) np.testing.assert_array_equal(result, expected) @@ -180,29 +194,34 @@ def test_extract_star_data_missing_key(mock_data): with pytest.raises(KeyError, match="Missing keys in dataset: \\['invalid_key'\\]"): extract_star_data(mock_data, train_key="invalid_key", test_key="stars") + def test_extract_star_data_partially_missing_key(mock_data): """Test that the function raises a KeyError if only one key is missing.""" - with pytest.raises(KeyError, match="Missing keys in dataset: \\['missing_stars'\\]"): + with pytest.raises( + KeyError, match="Missing keys in dataset: \\['missing_stars'\\]" + ): extract_star_data(mock_data, train_key="noisy_stars", test_key="missing_stars") def test_extract_star_data_tensor_conversion(mock_data): """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - + assert isinstance(result, np.ndarray), "The result should be a NumPy array" assert result.dtype == np.float32, "The NumPy array should have dtype float32" def test_reference_shifts_broadcasting(): - reference_shifts = [-1/3, -1/3] # Example reference_shifts + reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts shifts = np.random.rand(2, 2400) # Example shifts array # Ensure reference_shifts is a NumPy array (if it's not already) reference_shifts = np.array(reference_shifts) # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts[:, None], shifts.shape) # Shape will be (2, 2400) + reference_shifts = np.broadcast_to( + reference_shifts[:, None], shifts.shape + ) # Shape will be (2, 2400) # Ensure shapes are compatible for subtraction displacements = reference_shifts - shifts diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index 1ebc00cb..de111427 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ 
b/src/wf_psf/tests/test_data/test_data_utils.py @@ -1,8 +1,13 @@ - class MockDataset: def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} - + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + class MockData: def __init__( self, @@ -16,15 +21,16 @@ def __init__( masks=None, ): self.training_data = MockDataset( - positions=training_positions, + positions=training_positions, zernike_priors=training_zernike_priors, star_type="noisy_stars", stars=noisy_stars, - masks=noisy_masks) + masks=noisy_masks, + ) self.test_data = MockDataset( - positions=test_positions, + positions=test_positions, zernike_priors=test_zernike_priors, star_type="stars", stars=stars, - masks=masks) - + masks=masks, + ) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index ec9f0495..05728de7 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -14,13 +14,14 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( - InferenceConfigHandler, + InferenceConfigHandler, PSFInference, - PSFInferenceEngine + PSFInferenceEngine, ) from wf_psf.utils.read_config import RecursiveNamespace + def _patch_data_handler(): """Helper for patching data_handler to avoid full PSF logic.""" patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) @@ -34,6 +35,7 @@ def fake_process(x): mock_instance.process_sed_data.side_effect = fake_process return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -60,13 +62,13 @@ def mock_training_config(): LP_filter_length=3, param_hparams=RecursiveNamespace( n_zernikes=10, - - ) - ) - ) + ), + ), + ) ) return training_config + @pytest.fixture def mock_inference_config(): inference_config = RecursiveNamespace( @@ -74,16 +76,12 @@ def mock_inference_config(): batch_size=16, cycle=2, configs=RecursiveNamespace( - trained_model_path='/path/to/trained/model', - model_subdir='psf_model', - trained_model_config_path='config/training_config.yaml', - data_config_path=None - ), - model_params=RecursiveNamespace( - n_bins_lda=8, - output_Q=1, - output_dim=64 + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), ) ) return inference_config @@ -96,14 +94,18 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins, 2) + seds=np.random.rand(num_sources, num_bins, 2), ) inference._config_handler = MagicMock() inference._config_handler.inference_config 
= mock_inference_config @@ -116,9 +118,10 @@ def psf_test_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } + @pytest.fixture def psf_single_star_setup(mock_inference_config): num_sources = 1 @@ -128,14 +131,18 @@ def psf_single_star_setup(mock_inference_config): # Single position mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) # Shape (1, 2, num_bins) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", - x_field=0.1, # scalar for single star + x_field=0.1, # scalar for single star y_field=0.1, - seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -148,7 +155,7 @@ def psf_single_star_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } @@ -165,7 +172,9 @@ def test_set_config_paths(mock_inference_config): # Assertions assert config_handler.trained_model_path == Path("/path/to/trained/model") assert config_handler.model_subdir == "psf_model" - assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) assert config_handler.data_config_path == None @@ -175,15 +184,19 @@ def test_overwrite_model_params(mock_training_config, mock_inference_config): training_config = mock_training_config inference_config = mock_inference_config - InferenceConfigHandler.overwrite_model_params( - training_config, inference_config - ) + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) # Assert that the model_params were overwritten correctly - assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" - assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" - - assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" def test_prepare_configs(mock_training_config, mock_inference_config): @@ -196,7 +209,7 @@ def test_prepare_configs(mock_training_config, mock_inference_config): original_model_params = mock_training_config.training.model_params # Instantiate PSFInference - psf_inf = PSFInference('/dummy/path.yaml') + psf_inf = PSFInference("/dummy/path.yaml") # Mock the config handler attribute with a mock InferenceConfigHandler mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -204,7 +217,9 @@ def test_prepare_configs(mock_training_config, mock_inference_config): 
mock_config_handler.inference_config = inference_config # Patch the overwrite_model_params to use the real static method - mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) psf_inf._config_handler = mock_config_handler @@ -223,30 +238,43 @@ def test_config_handler_lazy_load(monkeypatch): class DummyHandler: def load_configs(self): - called['load'] = True + called["load"] = True self.inference_config = {} self.training_config = {} self.data_config = {} - def overwrite_model_params(self, *args): pass - monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) inference.prepare_configs() - assert 'load' in called # Confirm lazy load happened + assert "load" in called # Confirm lazy load happened + def test_batch_size_positive(): inference = PSFInference("dummy_path.yaml") inference._config_handler = MagicMock() inference._config_handler.inference_config = SimpleNamespace( - inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) ) assert inference.batch_size == 4 -@patch('wf_psf.inference.psf_inference.DataHandler') -@patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): mock_data_config = MagicMock() mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -255,29 +283,33 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mo mock_config_handler.inference_config = mock_inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + psf_inf = PSFInference("dummy_path.yaml") psf_inf._config_handler = mock_config_handler psf_inf.load_inference_model() weights_path_pattern = os.path.join( - mock_config_handler.trained_model_path, - mock_config_handler.model_subdir, - f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" - ) + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_training_config, - mock_data_config, - weights_path_pattern + mock_training_config, mock_data_config, weights_path_pattern ) -@patch.object(PSFInference, 'prepare_configs') -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def 
test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -294,8 +326,11 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_ mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) mock_prepare_configs.assert_called_once() + @patch("wf_psf.inference.psf_inference.psf_models.simPSF") -def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): """Test that simPSF uses the updated model parameters.""" training_config = mock_training_config inference_config = mock_inference_config @@ -315,7 +350,7 @@ def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc mock_config_handler.inference_config = inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + modeller = PSFInference("dummy_path.yaml") modeller._config_handler = mock_config_handler @@ -330,9 +365,11 @@ def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc assert result.output_Q == expected_output_Q -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_get_psfs_runs_inference( + mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -355,7 +392,6 @@ def fake_compute_psfs(positions, seds): assert mock_compute_psfs.call_count == 1 - def test_single_star_inference_shape(psf_single_star_setup): setup = psf_single_star_setup @@ -374,9 +410,11 @@ def test_single_star_inference_shape(psf_single_star_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (1, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (1, num_bins, 2)" - + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" def test_multiple_star_inference_shape(psf_test_setup): @@ -398,9 +436,12 @@ def test_multiple_star_inference_shape(psf_test_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (2, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (2, num_bins, 2)" - + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + def test_valueerror_on_mismatched_batches(psf_single_star_setup): """Raise if sed_data batch size != positions batch size and sed_data != 1.""" @@ -416,7 +457,9 @@ def 
test_valueerror_on_mismatched_batches(psf_single_star_setup): inference.seds = bad_sed inference.positions = np.ones((1, 2), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): inference._prepare_positions_and_seds() finally: patcher.stop() @@ -435,7 +478,9 @@ def test_valueerror_on_mismatched_positions(psf_single_star_setup): inference.x_field = np.ones((3, 1), dtype=np.float32) inference.y_field = np.ones((3, 1), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): inference._prepare_positions_and_seds() finally: - patcher.stop() \ No newline at end of file + patcher.stop() diff --git a/src/wf_psf/tests/test_metrics/conftest.py b/src/wf_psf/tests/test_metrics/conftest.py index d015024e..f1fe85de 100644 --- a/src/wf_psf/tests/test_metrics/conftest.py +++ b/src/wf_psf/tests/test_metrics/conftest.py @@ -7,6 +7,7 @@ """ + import pytest from unittest.mock import patch, MagicMock import numpy as np @@ -27,6 +28,7 @@ def load_weights(self, *args, **kwargs): # Simulate the weight loading pass + class TFGroundTruthSemiParametricField(TFSemiParametricField): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -35,6 +37,7 @@ def __init__(self, *args, **kwargs): def call(self, inputs, **kwargs): return inputs + @pytest.fixture def mock_psf_model(): # Return a mock instance of TFSemiParametricField @@ -42,8 +45,10 @@ def mock_psf_model(): psf_model.load_weights = MagicMock() # Mock load_weights method return psf_model + @pytest.fixture def mock_get_psf_model(mock_psf_model): - with patch('wf_psf.psf_models.psf_models.get_psf_model', return_value=mock_psf_model) as mock_method: + with patch( + "wf_psf.psf_models.psf_models.get_psf_model", return_value=mock_psf_model + ) as mock_method: yield mock_method - diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 225ba894..c9084503 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,31 +1,30 @@ - from unittest.mock import patch, MagicMock import pytest from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler from wf_psf.data.data_handler import DataHandler + @pytest.fixture def mock_metrics_params(): return MagicMock( eval_mono_metric=True, eval_opd_metric=False, eval_test_shape_results_dict=True, - eval_train_shape_results_dict=False + eval_train_shape_results_dict=False, ) + @pytest.fixture def mock_trained_model_params(): - return MagicMock( - model_params=MagicMock(model_name="mock_model"), - id_name="mock_id" - ) + return MagicMock(model_params=MagicMock(model_name="mock_model"), id_name="mock_id") + @pytest.fixture def mock_data(): # Create mock instances of the required attributes mock_data_params = MagicMock() mock_simPSF = MagicMock() - + # Mock the `data_params` dictionary for "train" and "test" data mock_data_params.train = MagicMock() mock_data_params.test = MagicMock() @@ -41,34 +40,50 @@ def mock_data(): mock_data_handler.test_data = MagicMock() mock_data_handler.training_data.dataset = { - 'positions': 'train_positions', - 'noisy_stars': 'train_noisy_stars', + "positions": "train_positions", + 
"noisy_stars": "train_noisy_stars", "SEDs": "train_SEDs", - "C_poly": "train_C_poly" + "C_poly": "train_C_poly", } mock_data_handler.test_data.dataset = { - 'positions': 'test_positions', - 'noisy_stars': 'test_noisy_stars', + "positions": "test_positions", + "noisy_stars": "test_noisy_stars", "SEDs": "test_SEDs", - "C_poly": "test_C_poly" + "C_poly": "test_C_poly", } - mock_data_handler.sed_data = 'mock_sed_data' - + mock_data_handler.sed_data = "mock_sed_data" + # Return the mocked DataHandler instance return mock_data_handler -def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_data, mock_psf_model, mocker): +def test_evaluate_model( + mock_metrics_params, mock_trained_model_params, mock_data, mock_psf_model, mocker +): # Mock the metric functions - with patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_polychromatic_lowres', new_callable=MagicMock) as mock_evaluate_poly_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_mono_rmse', new_callable=MagicMock) as mock_evaluate_mono_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_opd', new_callable=MagicMock) as mock_evaluate_opd_metric, \ - patch('wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_shape', new_callable=MagicMock) as mock_evaluate_shape_results_dict, \ - patch('numpy.save', new_callable=MagicMock) as mock_np_save: + with ( + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_polychromatic_lowres", + new_callable=MagicMock, + ) as mock_evaluate_poly_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_mono_rmse", + new_callable=MagicMock, + ) as mock_evaluate_mono_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_opd", + new_callable=MagicMock, + ) as mock_evaluate_opd_metric, + patch( + "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_shape", + new_callable=MagicMock, + ) as mock_evaluate_shape_results_dict, + patch("numpy.save", new_callable=MagicMock) as mock_np_save, + ): # Mock the logger - logger = mocker.patch('wf_psf.metrics.metrics_interface.logger') + logger = mocker.patch("wf_psf.metrics.metrics_interface.logger") # Call evaluate_model evaluate_model( @@ -76,12 +91,16 @@ def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_dat trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - metrics_output="/mock/metrics/output" + metrics_output="/mock/metrics/output", ) - # Assertions for metric functions - assert mock_evaluate_poly_metric.call_count == 2 # Called twice, once for each dataset - assert mock_evaluate_mono_metric.call_count == 2 # Called twice, once for each dataset + # Assertions for metric functions + assert ( + mock_evaluate_poly_metric.call_count == 2 + ) # Called twice, once for each dataset + assert ( + mock_evaluate_mono_metric.call_count == 2 + ) # Called twice, once for each dataset mock_evaluate_opd_metric.assert_not_called() # Should not be called because the flag is False mock_evaluate_shape_results_dict.assert_called_once() # Should be called only for the test dataset @@ -90,5 +109,7 @@ def test_evaluate_model(mock_metrics_params, mock_trained_model_params, mock_dat # Validate the np.save call arguments output_path, saved_data = mock_np_save.call_args[0] # Extract arguments - assert "/mock/metrics/output/metrics-mock_modelmock_id" in output_path # Ensure 
correct path format - assert isinstance(saved_data, dict) # Ensure data being saved is a dictionary \ No newline at end of file + assert ( + "/mock/metrics/output/metrics-mock_modelmock_id" in output_path + ) # Ensure correct path format + assert isinstance(saved_data, dict) # Ensure data being saved is a dictionary diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 95ad6287..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -32,16 +32,16 @@ def zks_prior(): def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) mock_instance.run_type = "training" - + training_dataset = { "positions": np.array([[1, 2], [3, 4]]), "zernike_prior": zks_prior, - "noisy_stars": np.zeros((2, 1, 1, 1)), + "noisy_stars": np.zeros((2, 1, 1, 1)), } test_dataset = { "positions": np.array([[5, 6], [7, 8]]), "zernike_prior": zks_prior, - "stars": np.zeros((2, 1, 1, 1)), + "stars": np.zeros((2, 1, 1, 1)), } mock_instance.training_data = mocker.Mock() @@ -60,44 +60,53 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock + @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data): # Patch expensive methods during construction to avoid errors - with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): - from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField - instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) return instance + def test_compute_zernikes(mocker, physical_layer_instance): # Expected output of mock components - padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 + ) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) n_zks_total = physical_layer_instance.n_zks_total expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) expected_values = tf.constant( - [[[[v]] for v in expected_values_list]], - dtype=tf.float32 -) + [[[[v]] for v in expected_values_list]], dtype=tf.float32 + ) # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, "tf_poly_Z_field", - return_value=padded_zernike_param + return_value=padded_zernike_param, ) # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - TFPhysicalPolychromaticField, - "tf_physical_layer", - mock_tf_physical_layer + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) # Patch pad_tf_zernikes function mocker.patch( "wf_psf.data.data_zernike_utils.pad_tf_zernikes", - return_value=(padded_zernike_param, 
padded_zernike_prior) + return_value=(padded_zernike_param, padded_zernike_prior), ) # Run the test diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index b7c906f6..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -10,7 +10,7 @@ from wf_psf.psf_models import psf_models from wf_psf.psf_models.models import ( psf_model_semiparametric, - psf_model_physical_polychromatic + psf_model_physical_polychromatic, ) import tensorflow as tf import numpy as np diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 2f2b6c9c..8898df91 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -12,9 +12,9 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler from wf_psf.utils.configs_handler import ( - TrainingConfigHandler, - MetricsConfigHandler, - DataConfigHandler + TrainingConfigHandler, + MetricsConfigHandler, + DataConfigHandler, ) import os @@ -110,9 +110,8 @@ def test_get_run_config(path_to_repo_dir, path_to_tmp_output_dir, path_to_config assert type(config_class) is RegisterConfigClass -def test_data_config_handler_init( - mock_training_conf, mock_data_read_conf, mocker -): + +def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocker): # Mock read_conf function mock_data_read_conf() @@ -130,13 +129,16 @@ def test_data_config_handler_init( # Patch load_dataset to assign dataset def mock_load_dataset(self): - self.dataset = {"SEDs": ["dummy_sed_data"], "positions": ["dummy_positions_data"]} + self.dataset = { + "SEDs": ["dummy_sed_data"], + "positions": ["dummy_positions_data"], + } mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset) # Create DataConfigHandler instance data_config_handler = DataConfigHandler( - "/path/to/data_config.yaml", + "/path/to/data_config.yaml", mock_training_conf.training.model_params, mock_training_conf.training.training_hparams.batch_size, ) @@ -153,12 +155,11 @@ def mock_load_dataset(self): == mock_training_conf.training.model_params.n_bins_lda ) assert ( - data_config_handler.batch_size + data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size ) - def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): # Mock read_conf function mocker.patch( diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index 7d125014..adabeaa0 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -16,6 +16,7 @@ ) from wf_psf.sims.psf_simulator import PSFSimulator + def test_initialization(): """Test if NoiseEstimator initializes correctly.""" img_dim = (50, 50) @@ -27,6 +28,7 @@ def test_initialization(): assert isinstance(estimator.window, np.ndarray) assert estimator.window.shape == img_dim + def test_init_window(): """Test that the exclusion window is correctly initialized.""" img_dim = (50, 50) @@ -41,13 +43,17 @@ def test_init_window(): inside_radius = np.sqrt((x - mid_x) ** 2 + (y - mid_y) ** 2) <= win_rad assert estimator.window[x, y] == (not inside_radius) + def test_sigma_mad(): """Test the MAD-based standard deviation estimation.""" - data = np.array([1, 1, 2, 2, 3, 3, 4, 4, 100]) # Outlier should not heavily influence MAD + data = np.array( + [1, 
1, 2, 2, 3, 3, 4, 4, 100] + ) # Outlier should not heavily influence MAD expected_sigma = 1.4826 * np.median(np.abs(data - np.median(data))) assert np.isclose(NoiseEstimator.sigma_mad(data), expected_sigma, atol=1e-4) + def test_estimate_noise_without_default_window(): """Test noise estimation with the default exclusion window (no custom mask).""" img_dim = (50, 50) @@ -59,10 +65,11 @@ def test_estimate_noise_without_default_window(): image = np.random.normal(0, 10, img_dim) noise_estimation = estimator.estimate_noise(image) - + # The estimated noise should be close to 10 (the true std) assert np.isclose(noise_estimation, 10, atol=2) + def test_estimate_noise_with_custom_mask(): """Test noise estimation with a custom mask applied outside the exclusion radius.""" img_dim = (50, 50) @@ -80,6 +87,7 @@ def test_estimate_noise_with_custom_mask(): assert np.isclose(noise_estimation, 5, atol=1) + def test_apply_mask_with_none_mask(): """Test apply_mask when mask is None.""" img_dim = (10, 10) @@ -88,13 +96,16 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), "apply_mask should return the window when mask is None." + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." + def test_apply_mask_with_valid_mask(): """Test apply_mask when a valid mask is provided.""" img_dim = (10, 10) estimator = NoiseEstimator(img_dim, win_rad=3) - + # Create a custom mask custom_mask = np.ones(img_dim, dtype=bool) custom_mask[5, 5] = False # Set a pixel to False to exclude it from the window @@ -103,7 +114,10 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), "apply_mask did not apply the mask correctly." + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." + def test_apply_mask_with_zeroed_mask(): """Test apply_mask when a zeroed mask is provided.""" @@ -117,7 +131,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), "apply_mask did not handle the zeroed mask correctly." + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_unobscured_zernike_projection(): @@ -180,7 +196,9 @@ def test_tf_decompose_obscured_opd_basis(): tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32) # Create obscurations - obscurations = PSFSimulator.generate_euclid_pupil_obscurations(N_pix=wfe_dim, N_filter=2) + obscurations = PSFSimulator.generate_euclid_pupil_obscurations( + N_pix=wfe_dim, N_filter=2 + ) tf_obscurations = tf.convert_to_tensor(obscurations, dtype=tf.float32) # Create random zernike coefficient array diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index 2da47900..38161edc 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -34,6 +34,7 @@ def get_gpu_info(): device_name = tf.test.gpu_device_name() return device_name + def setup_training(): """Set up Training. 
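The sigma_mad test above pins NoiseEstimator.sigma_mad to the usual Gaussian-consistency convention, sigma ~ 1.4826 x MAD, which is what makes the estimate robust to the single outlier in the fixture. A minimal standalone sketch of that estimator (an illustration mirroring the tested behavior, not the wf_psf implementation):

    import numpy as np

    def sigma_mad(data):
        # Median absolute deviation, scaled by 1.4826 so the estimate
        # matches the standard deviation for Gaussian-distributed data.
        return 1.4826 * np.median(np.abs(data - np.median(data)))

    data = np.array([1, 1, 2, 2, 3, 3, 4, 4, 100])
    print(sigma_mad(data))  # ~1.48: the outlier barely moves the estimate
    print(np.std(data))     # ~30.7: a plain standard deviation is dominated by it
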
@@ -274,7 +275,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. - + Parameters ---------- training_handler: TrainingParamsHandler @@ -296,7 +297,7 @@ def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): Tensor containing the outputs for training output_val: tf.Tensor Tensor containing the outputs for validation - + """ if training_handler.training_hparams.loss == "mask_mse": @@ -305,10 +306,18 @@ def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()] non_param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()] outputs = tf.stack( - [data_conf.training_data.dataset["noisy_stars"], data_conf.training_data.dataset["masks"]], axis=-1 + [ + data_conf.training_data.dataset["noisy_stars"], + data_conf.training_data.dataset["masks"], + ], + axis=-1, ) output_val = tf.stack( - [data_conf.test_data.dataset["stars"], data_conf.test_data.dataset["masks"]], axis=-1 + [ + data_conf.test_data.dataset["stars"], + data_conf.test_data.dataset["masks"], + ], + axis=-1, ) else: loss = tf.keras.losses.MeanSquaredError() @@ -333,7 +342,7 @@ def train( This function manages multi-cycle training of a parametric + non-parametric PSF model, including initialization, loss/metric configuration, optimizer setup, model checkpointing, - and optional projection or resetting of non-parametric features. Each cycle can include + and optional projection or resetting of non-parametric features. Each cycle can include both parametric and non-parametric training stages, and training history is saved for each. Parameters @@ -364,7 +373,7 @@ def train( None Side Effects - ------------ + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations @@ -392,10 +401,10 @@ def train( current_cycle += 1 # Instantiate fresh loss, monitor, and independent metric objects per training phase (param / non-param) - loss, param_metrics, non_param_metrics, monitor, outputs, output_val = get_loss_metrics_monitor_and_outputs( - training_handler, data_conf + loss, param_metrics, non_param_metrics, monitor, outputs, output_val = ( + get_loss_metrics_monitor_and_outputs(training_handler, data_conf) ) - + # If projected learning is enabled project DD_features. if hasattr(psf_model, "project_dd_features") and psf_model.project_dd_features: if current_cycle > 1: diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 8de55714..c8c4ea3a 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -122,7 +122,7 @@ class DataConfigHandler: training_model_params : Recursive Namespace object Recursive Namespace object containing the training model parameters batch_size : int - Training hyperparameter used for batched pre-processing of data. + Training hyperparameter used for batched pre-processing of data. 
""" @@ -134,7 +134,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr exit() self.simPSF = psf_models.simPSF(training_model_params) - + # Extract sub-configs early train_params = self.data_conf.data.training test_params = self.data_conf.data.test @@ -153,7 +153,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) - + self.batch_size = batch_size @@ -264,9 +264,11 @@ def __init__(self, metrics_conf, file_handler, training_conf=None): self.training_conf = training_conf self.data_conf = self._load_data_conf() self.data_conf.run_type = "metrics" - self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) self.trained_psf_model = self._load_trained_psf_model() - + @property def metrics_conf(self): return self._metrics_conf @@ -285,7 +287,9 @@ def training_conf(self, training_conf): if training_conf is None: try: training_conf_path = self._get_training_conf_path_from_metrics() - logger.info(f"Loading training config from inferred path: {training_conf_path}") + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) self._training_conf = read_conf(training_conf_path) except Exception as e: logger.error(f"Failed to load training config: {e}") @@ -307,14 +311,11 @@ def _load_trained_psf_model(self): model_name = self.training_conf.training.model_params.model_name id_name = self.training_conf.training.id_name - + weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), ) return load_trained_psf_model( self.training_conf, @@ -322,7 +323,6 @@ def _load_trained_psf_model(self): weights_path_pattern, ) - def _get_training_conf_path_from_metrics(self): """ Retrieves the full path to the training config based on the metrics configuration. @@ -344,22 +344,27 @@ def _get_training_conf_path_from_metrics(self): try: training_conf_filename = self._metrics_conf.metrics.trained_model_config except AttributeError as e: - raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e training_conf_path = os.path.join( - self._file_handler.get_config_dir(trained_model_path), training_conf_filename) + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, + ) if not os.path.exists(training_conf_path): - raise FileNotFoundError(f"Training config file not found: {training_conf_path}") + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" + ) return training_conf_path - def _get_trained_model_path(self): """ Determine the trained model path from either: - - 1. The metrics configuration file (i.e., for metrics-only runs after training), or + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns @@ -372,14 +377,18 @@ def _get_trained_model_path(self): ConfigParameterError If the path specified in the metrics config is invalid or missing. 
""" - trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) if trained_model_path: if not os.path.isdir(trained_model_path): raise ConfigParameterError( f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - logger.info(f"Using trained model path from metrics config: {trained_model_path}") + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" + ) return trained_model_path # Fallback for single-run training + metrics evaluation mode @@ -388,7 +397,9 @@ def _get_trained_model_path(self): self._file_handler.parent_output_dir, self._file_handler.workdir, ) - logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) return fallback_path def _load_data_conf(self): @@ -413,7 +424,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index f6ca3bf8..922c34b6 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -33,9 +33,9 @@ class RecursiveNamespace(SimpleNamespace): def __init__(self, **kwargs): super().__init__(**kwargs) for key, val in kwargs.items(): - if isinstance(val,dict): + if isinstance(val, dict): setattr(self, key, RecursiveNamespace(**val)) - elif isinstance(val,list): + elif isinstance(val, list): setattr(self, key, list(map(self.map_entry, val))) @staticmethod diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index f13d2c59..7c8d5e41 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -19,6 +19,7 @@ def scale_to_range(input_array, old_range, new_range): input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] return input_array + def ensure_batch(arr): """ Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). From 4e9add16dd39d9aa8d30f7589fc63b45df041f9a Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:31:46 +0100 Subject: [PATCH 105/146] Correct type hint errors --- src/wf_psf/data/data_handler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index c0ddf7b0..bdcf9a6b 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -241,7 +241,7 @@ def process_sed_data(self, sed_data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """ + """ Extract and concatenate star-related data from training and test datasets. This function retrieves arrays (e.g., postage stamps, masks, positions) from @@ -310,8 +310,8 @@ def get_data_array( key: str = None, train_key: str = None, test_key: str = None, - allow_missing: bool = False, -) -> np.ndarray | None: + allow_missing: bool = True, +) -> Optional[np.ndarray]: """ Retrieve data from dataset depending on run type. @@ -337,7 +337,7 @@ def get_data_array( test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. 
- allow_missing : bool, default False + allow_missing : bool, default True Control behavior when data is missing or keys are not found: - True: Return None instead of raising exceptions - False: Raise appropriate exceptions (KeyError, ValueError) @@ -401,7 +401,7 @@ def get_data_array( raise -def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: """ Extract data directly with proper error handling and type conversion. From 7c07a7e65d99f2e3edb08accc451232ab4fa3895 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:16 +0100 Subject: [PATCH 106/146] Remove unused import --- src/wf_psf/instrument/ccd_misalignments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 6b745ad4..9f2fb221 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,6 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: From 71a50eaf1b9bfd3987327c505c349b049f5c5bca Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:56 +0100 Subject: [PATCH 107/146] Replace call to deprecated get_np_obs_positions with get_data_array --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index df18bfd3..c94ca2e5 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_np_obs_positions +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -222,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) + self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From 0fd8b9ee73a775145880bde71a451e55475e07f0 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:33:40 +0100 Subject: [PATCH 108/146] Remove unused imports and reformat --- src/wf_psf/data/centroids.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 20391793..01135428 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,9 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -import tensorflow as tf from typing import Optional @@ -252,7 +250,6 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" - # Convert to np.ndarray if not already im = np.asarray(im) if mask is not None: From 
660cee23ad88dad3e7553319351f49ceeea87542 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:35:13 +0100 Subject: [PATCH 109/146] Update fixtures and unit tests --- src/wf_psf/tests/test_data/centroids_test.py | 41 ++--- src/wf_psf/tests/test_data/conftest.py | 16 +- .../tests/test_data/data_handler_test.py | 148 ++++++++++++++++-- .../test_data/data_zernike_utils_test.py | 30 +++- .../test_inference/psf_inference_test.py | 1 - 5 files changed, 192 insertions(+), 44 deletions(-) diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 1cb0db16..e22ab042 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,8 +9,6 @@ import numpy as np import pytest from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator -from wf_psf.data.data_handler import extract_star_data -from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch @@ -116,24 +114,23 @@ def test_compute_centroid_correction_with_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } + # Mock the internal function calls: with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call the function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Ensure the result has the correct shape assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) @@ -148,10 +145,6 @@ def test_compute_centroid_correction_with_masks(mock_data): def test_compute_centroid_correction_without_masks(mock_data): """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - # Define model parameters model_params = RecursiveNamespace( pix_sampling=12e-6, # Example pixel sampling in meters @@ -159,26 +152,23 @@ def test_compute_centroid_correction_without_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } + # Mock internal function calls with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - # Mock compute_zernike_tip_tilt 
assuming no masks mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Validate result shape assert result.shape == (4, 3) # (n_stars, 3 Zernike components) @@ -274,7 +264,6 @@ def test_centroid_estimator_initialization(simple_image): def test_single_iteration(centroid_estimator): """Test that the internal methods are called exactly once for n_iter=1.""" - # Mock the methods centroid_estimator.update_grid = MagicMock() centroid_estimator.elliptical_gaussian = MagicMock() @@ -294,7 +283,6 @@ def test_single_iteration(centroid_estimator): def test_single_iteration_auto_run(simple_image): """Test that the internal methods are called exactly once for n_iter=1.""" - # Patch the methods at the time the object is created with ( patch.object(CentroidEstimator, "update_grid") as update_grid_mock, @@ -372,7 +360,6 @@ def test_elliptical_gaussian(simple_image): def test_intra_pixel_shifts(simple_image_with_centroid): """Test the return_intra_pixel_shifts method.""" - centroid_estimator = simple_image_with_centroid # Calculate intra-pixel shifts diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 47eed929..131922e5 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -11,9 +11,11 @@ import pytest import numpy as np import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -129,6 +131,18 @@ def mock_data(scope="module"): ) +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index c1310d21..d29771a1 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -3,12 +3,10 @@ import tensorflow as tf from wf_psf.data.data_handler import ( DataHandler, - get_np_obs_positions, + get_data_array, extract_star_data, ) from wf_psf.utils.read_config import RecursiveNamespace -import logging -from unittest.mock import patch def mock_sed(): @@ -151,12 +149,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.validate_and_process_dataset() -def test_get_np_obs_positions(mock_data): - observed_positions = get_np_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") @@ -229,3 +221,141 @@ def test_reference_shifts_broadcasting(): # Test the result assert displacements.shape == shifts.shape, "Shapes do not match" assert 
np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # ================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 390d10f9..66d23309 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -11,7 +11,6 @@ compute_zernike_tip_tilt, pad_tf_zernikes, ) -from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -46,7 +45,27 @@ def test_training_without_prior(mock_model_params, mock_data): data=mock_data, run_type="training", model_params=mock_model_params ) - assert zinputs.centroid_dataset is mock_data + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + 
mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + ] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + assert zinputs.zernike_prior is None expected_positions = np.concatenate( @@ -88,7 +107,7 @@ def test_training_with_explicit_prior(mock_model_params, caplog): data, "training", mock_model_params, prior=explicit_prior ) - assert "Zernike prior explicitly provided" in caplog.text + assert "Explicit prior provided; ignoring dataset-based prior." in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -103,7 +122,8 @@ def test_inference_with_dict_and_prior(mock_model_params): zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) - assert zinputs.centroid_dataset is None + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None # NumPy array comparison np.testing.assert_array_equal( @@ -397,7 +417,6 @@ def test_pad_zernikes_param_greater_than_prior(): def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True @@ -456,7 +475,6 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 05728de7..28a2c1af 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -161,7 +161,6 @@ def psf_single_star_setup(mock_inference_config): def test_set_config_paths(mock_inference_config): """Test setting configuration paths.""" - # Initialize handler and inject mock config config_handler = InferenceConfigHandler("fake/path") config_handler.inference_config = mock_inference_config From 76c17f323bc53bc9df9162b106113300fbad417a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:36:32 +0100 Subject: [PATCH 110/146] Reformat with black --- src/wf_psf/data/data_handler.py | 41 ++- src/wf_psf/data/data_zernike_utils.py | 2 - src/wf_psf/data/old_zernike_prior.py | 335 ++++++++++++++++++ src/wf_psf/instrument/ccd_misalignments.py | 6 - src/wf_psf/metrics/metrics.py | 4 +- src/wf_psf/plotting/plots_interface.py | 2 + .../psf_models/models/psf_model_parametric.py | 1 - .../psf_model_physical_polychromatic.py | 3 - src/wf_psf/psf_models/psf_model_loader.py | 1 - src/wf_psf/psf_models/tf_modules/tf_layers.py | 10 +- .../psf_models/tf_modules/tf_psf_field.py | 4 +- src/wf_psf/psf_models/tf_modules/tf_utils.py | 1 - src/wf_psf/run.py | 4 +- src/wf_psf/sims/psf_simulator.py | 6 +- src/wf_psf/sims/spatial_varying_psf.py | 5 +- src/wf_psf/tests/test_metrics/conftest.py | 1 - .../test_metrics/metrics_interface_test.py | 2 +- .../tests/test_training/train_utils_test.py | 2 - .../tests/test_utils/configs_handler_test.py | 1 - src/wf_psf/tests/test_utils/conftest.py | 1 - src/wf_psf/training/train.py | 1 - src/wf_psf/utils/configs_handler.py 
| 10 +- src/wf_psf/utils/graph_utils.py | 8 +- src/wf_psf/utils/io.py | 2 - src/wf_psf/utils/read_config.py | 15 +- src/wf_psf/utils/utils.py | 8 +- 26 files changed, 383 insertions(+), 93 deletions(-) create mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bdcf9a6b..052fe730 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -118,7 +118,6 @@ def __init__( `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ - self.dataset_type = dataset_type self.data_params = data_params self.simPSF = simPSF @@ -182,7 +181,6 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["positions"] = ensure_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -244,8 +242,8 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ Extract and concatenate star-related data from training and test datasets. - This function retrieves arrays (e.g., postage stamps, masks, positions) from - both the training and test datasets using the specified keys, converts them + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them to NumPy if necessary, and concatenates them along the first axis. Parameters @@ -253,27 +251,27 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: data : DataConfigHandler Object containing training and test datasets. train_key : str - Key to retrieve data from the training dataset + Key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). test_key : str - Key to retrieve data from the test dataset + Key to retrieve data from the test dataset (e.g., 'stars', 'masks'). Returns ------- np.ndarray - Concatenated NumPy array containing the selected data from both + Concatenated NumPy array containing the selected data from both training and test sets. Raises ------ KeyError - If either the training or test dataset does not contain the + If either the training or test dataset does not contain the requested key. Notes ----- - - Designed for datasets with separate train/test splits, such as when + - Designed for datasets with separate train/test splits, such as when evaluating metrics on held-out data. - TensorFlow tensors are automatically converted to NumPy arrays. - Requires eager execution if TensorFlow tensors are present. @@ -304,6 +302,7 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + def get_data_array( data, run_type: str, @@ -334,7 +333,7 @@ def get_data_array( train_key : str, optional Key for training dataset access. If None and key is provided, defaults to key value. Default is None. - test_key : str, optional + test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. allow_missing : bool, default True @@ -355,12 +354,12 @@ def get_data_array( resolved for the operation and allow_missing=False. KeyError If the specified key is not found in the dataset and allow_missing=False. - + Notes ----- Key resolution follows this priority order: 1. train_key = train_key or key - 2. test_key = test_key or resolved_train_key + 2. test_key = test_key or resolved_train_key 3. 
key = key or resolved_train_key (for inference fallback) For TensorFlow tensors, the .numpy() method is called to convert to NumPy. @@ -370,13 +369,13 @@ def get_data_array( -------- >>> # Training data retrieval >>> train_data = get_data_array(data, "training", train_key="noisy_stars") - + >>> # Inference with fallback handling - >>> inference_data = get_data_array(data, "inference", key="positions", + >>> inference_data = get_data_array(data, "inference", key="positions", ... allow_missing=True) >>> if inference_data is None: ... print("No inference data available") - + >>> # Using key parameter for both train and inference >>> result = get_data_array(data, "inference", key="positions") """ @@ -384,18 +383,18 @@ def get_data_array( valid_run_types = {"training", "simulation", "metrics", "inference"} if run_type not in valid_run_types: raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") - + # Simplify key resolution with clear precedence effective_train_key = train_key or key effective_test_key = test_key or effective_train_key effective_key = key or effective_train_key - + try: if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference return _get_direct_data(data, effective_key, allow_missing) - except Exception as e: + except Exception: if allow_missing: return None raise @@ -417,7 +416,7 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray Returns ------- np.ndarray or None - Data converted to NumPy array, or None if allow_missing=True and + Data converted to NumPy array, or None if allow_missing=True and data is unavailable. Raises @@ -437,11 +436,11 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray if allow_missing: return None raise ValueError("No key provided for inference data") - + value = data.dataset.get(key, None) if value is None: if allow_missing: return None raise KeyError(f"Key '{key}' not found in inference dataset") - + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 399ee9ef..0fad9c8e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -181,7 +181,6 @@ def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): padded_zk_prior : tf.Tensor Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. """ - pad_num_param = n_zks_total - tf.shape(zk_param)[1] pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] @@ -230,7 +229,6 @@ def assemble_zernike_contributions( tf.Tensor A tensor representing the full Zernike contribution map. """ - zernike_contribution_list = [] # Prior diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py new file mode 100644 index 00000000..0feb3e70 --- /dev/null +++ b/src/wf_psf/data/old_zernike_prior.py @@ -0,0 +1,335 @@ +import numpy as np +import tensorflow as tf +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt +from fractions import Fraction +import logging + +logger = logging.getLogger(__name__) + + +def get_np_obs_positions(data): + """Get observed positions in numpy from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. 
+
+    Parameters
+    ----------
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    np.ndarray
+        Numpy array containing the observed positions of the stars.
+
+    Notes
+    -----
+    The observed positions are obtained by concatenating the positions of stars
+    from both the training and test datasets along the 0th axis.
+    """
+    obs_positions = np.concatenate(
+        (
+            data.training_data.dataset["positions"],
+            data.test_data.dataset["positions"],
+        ),
+        axis=0,
+    )
+
+    return obs_positions
+
+
+def get_obs_positions(data):
+    """Get observed positions from the provided dataset.
+
+    Parameters
+    ----------
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    tf.Tensor
+        Tensor containing the observed positions of the stars.
+    """
+    obs_positions = get_np_obs_positions(data)
+
+    return tf.convert_to_tensor(obs_positions, dtype=tf.float32)
+
+
+def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray:
+    """Extract specific star-related data from training and test datasets.
+
+    This function retrieves and concatenates specific star-related data, such as
+    star stamps or masks, from the training and test datasets based on the provided keys.
+
+    Parameters
+    ----------
+    data : DataConfigHandler
+        Object containing training and test datasets.
+    train_key : str
+        The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks').
+    test_key : str
+        The key to retrieve data from the test dataset (e.g., 'stars', 'masks').
+
+    Returns
+    -------
+    np.ndarray
+        A NumPy array containing the concatenated data for the given keys.
+
+    Raises
+    ------
+    KeyError
+        If the specified keys do not exist in the training or test datasets.
+
+    Notes
+    -----
+    - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays.
+    - Ensure that eager execution is enabled when calling this function.
+    """
+    # Ensure the requested keys exist in both training and test datasets
+    missing_keys = [
+        key
+        for key, dataset in [
+            (train_key, data.training_data.dataset),
+            (test_key, data.test_data.dataset),
+        ]
+        if key not in dataset
+    ]
+
+    if missing_keys:
+        raise KeyError(f"Missing keys in dataset: {missing_keys}")
+
+    # Retrieve data from training and test sets
+    train_data = data.training_data.dataset[train_key]
+    test_data = data.test_data.dataset[test_key]
+
+    # Convert to NumPy if necessary
+    if tf.is_tensor(train_data):
+        train_data = train_data.numpy()
+    if tf.is_tensor(test_data):
+        test_data = test_data.numpy()
+
+    # Concatenate and return
+    return np.concatenate((train_data, test_data), axis=0)
+
+
+def get_np_zernike_prior(data):
+    """Get the Zernike prior from the provided dataset.
+
+    This method concatenates the Zernike priors from both the training
+    and test datasets to obtain the full prior.
+
+    Parameters
+    ----------
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    zernike_prior : np.ndarray
+        Numpy array containing the full prior.
+    """
+    zernike_prior = np.concatenate(
+        (
+            data.training_data.dataset["zernike_prior"],
+            data.test_data.dataset["zernike_prior"],
+        ),
+        axis=0,
+    )
+
+    return zernike_prior
+
+
+def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray:
+    """Compute centroid corrections using Zernike polynomials.
+
+    This function calculates the Zernike contributions required to match the centroid
+    of the WaveDiff PSF model to the observed star centroids, processing in batches.
+
+    Parameters
+    ----------
+    model_params : RecursiveNamespace
+        An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters.
+
+    data : DataConfigHandler
+        An object containing training and test datasets, including observed PSFs
+        and optional star masks.
+
+    batch_size : int, optional
+        The batch size to use when processing the stars. Default is 1.
+
+
+    Returns
+    -------
+    zernike_centroid_array : np.ndarray
+        A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of
+        observed stars. The array contains the computed Zernike contributions,
+        with zero padding applied to the first column to ensure a consistent shape.
+    """
+    star_postage_stamps = extract_star_data(
+        data=data, train_key="noisy_stars", test_key="stars"
+    )
+
+    # Get star mask catalogue only if "masks" exist in both training and test datasets
+    star_masks = (
+        extract_star_data(data=data, train_key="masks", test_key="masks")
+        if (
+            data.training_data.dataset.get("masks") is not None
+            and data.test_data.dataset.get("masks") is not None
+            and tf.size(data.training_data.dataset["masks"]) > 0
+            and tf.size(data.test_data.dataset["masks"]) > 0
+        )
+        else None
+    )
+
+    pix_sampling = model_params.pix_sampling * 1e-6  # Change units from [um] to [m]
+
+    # Ensure star_masks is properly handled
+    star_masks = star_masks if star_masks is not None else None
+
+    reference_shifts = [
+        float(Fraction(value)) for value in model_params.reference_shifts
+    ]
+
+    n_stars = len(star_postage_stamps)
+    zernike_centroid_array = []
+
+    # Batch process the stars
+    for i in range(0, n_stars, batch_size):
+        batch_postage_stamps = star_postage_stamps[i : i + batch_size]
+        batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None
+
+        # Compute Zernike 1 and Zernike 2 for the batch
+        zk1_2_batch = -1.0 * compute_zernike_tip_tilt(
+            batch_postage_stamps, batch_masks, pix_sampling, reference_shifts
+        )
+
+        # Zero pad array for each batch and append
+        zernike_centroid_array.append(
+            np.pad(
+                zk1_2_batch,
+                pad_width=[(0, 0), (1, 0)],
+                mode="constant",
+                constant_values=0,
+            )
+        )
+
+    # Combine all batches into a single array
+    return np.concatenate(zernike_centroid_array, axis=0)
+
+
+def compute_ccd_misalignment(model_params, data):
+    """Compute CCD misalignment.
+
+    Parameters
+    ----------
+    model_params : RecursiveNamespace
+        Object containing parameters for this PSF model class.
+    data : DataConfigHandler
+        Object containing training and test datasets.
+
+    Returns
+    -------
+    zernike_ccd_misalignment_array : np.ndarray
+        Numpy array containing the Zernike contributions to model the CCD chip misalignments.
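+
+    Notes
+    -----
+    The single computed ``zk4`` value is placed in the fourth Zernike slot,
+    with zeros in the first three. As a purely illustrative sketch, a star
+    with ``zk4 = 0.05`` yields the row ``[0., 0., 0., 0.05]``.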
+ """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + +def get_zernike_prior(model_params, data, batch_size: int = 16): + """Get Zernike priors from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + + Notes + ----- + The Zernike prior are obtained by concatenating the Zernike priors + from both the training and test datasets along the 0th axis. + + """ + # List of zernike contribution + zernike_contribution_list = [] + + if model_params.use_prior: + logger.info("Reading in Zernike prior into Zernike contribution list...") + zernike_contribution_list.append(get_np_zernike_prior(data)) + + if model_params.correct_centroids: + logger.info("Adding centroid correction to Zernike contribution list...") + zernike_contribution_list.append( + compute_centroid_correction(model_params, data, batch_size) + ) + + if model_params.add_ccd_misalignments: + logger.info("Adding CCD mis-alignments to Zernike contribution list...") + zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) + + if len(zernike_contribution_list) == 1: + zernike_contribution = zernike_contribution_list[0] + else: + # Get max zk order + max_zk_order = np.max( + np.array( + [ + zk_contribution.shape[1] + for zk_contribution in zernike_contribution_list + ] + ) + ) + + zernike_contribution = np.zeros( + (zernike_contribution_list[0].shape[0], max_zk_order) + ) + + # Pad arrays to get the same length and add the final contribution + for it in range(len(zernike_contribution_list)): + current_zk_order = zernike_contribution_list[it].shape[1] + current_zernike_contribution = np.pad( + zernike_contribution_list[it], + pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], + mode="constant", + constant_values=0, + ) + + zernike_contribution += current_zernike_contribution + + return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 9f2fb221..da21f630 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -162,7 +162,6 @@ def _preprocess_tile_data(self) -> None: def _initialize_polygons(self): """Initialize polygons to look for CCD IDs""" - # Build polygon list corresponding to each CCD self.ccd_polygons = [] @@ -238,7 +237,6 @@ def scale_position_to_tile_reference(self, pos): Focal plane position in wavediff coordinate 
            system respecting `self.x_lims` and `self.y_lims`.
            Shape: (2,)
        """
-
        self.check_position_wavediff_limits(pos)

        pos_x = pos[0]
@@ -266,7 +264,6 @@ def scale_position_to_wavediff_reference(self, pos):
        pos : np.ndarray
            Tile position in input tile coordinate system. Shape: (2,)
        """
-
        self.check_position_tile_limits(pos)

        pos_x = pos[0]
@@ -286,7 +283,6 @@ def scale_position_to_wavediff_reference(self, pos):

    def check_position_wavediff_limits(self, pos):
        """Check if position is within wavediff limits."""
-
        if (pos[0] < self.x_lims[0] or pos[0] > self.x_lims[1]) or (
            pos[1] < self.y_lims[0] or pos[1] > self.y_lims[1]
        ):
@@ -296,7 +292,6 @@ def check_position_wavediff_limits(self, pos):

    def check_position_tile_limits(self, pos):
        """Check if position is within tile limits."""
-
        if (pos[0] < self.tiles_x_lims[0] or pos[0] > self.tiles_x_lims[1]) or (
            pos[1] < self.tiles_y_lims[0] or pos[1] > self.tiles_y_lims[1]
        ):
@@ -417,7 +412,6 @@ def compute_z_from_plane_data(pos, normal, d):
        d : np.ndarray
            `d` value from the plane equation. Shape (3,)
        """
-
        z = (-normal[0] * pos[0] - normal[1] * pos[1] - d) * 1.0 / normal[2]

        return z
diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index 3bf671a4..4324d6e6 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -216,8 +216,8 @@ def compute_mono_metric(
        lambda_obs = lambda_list[it]
        phase_N = simPSF_np.feasible_N(lambda_obs)

-        residuals = np.zeros((total_samples))
-        gt_star_mean = np.zeros((total_samples))
+        residuals = np.zeros(total_samples)
+        gt_star_mean = np.zeros(total_samples)

        # Total number of epochs
        n_epochs = int(np.ceil(total_samples / batch_size))
diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py
index 3ab6ab78..613d488e 100644
--- a/src/wf_psf/plotting/plots_interface.py
+++ b/src/wf_psf/plotting/plots_interface.py
@@ -369,6 +369,7 @@ class ShapeMetricsPlotHandler:
    """ShapeMetricsPlotHandler class.

    A class to handle plot parameters for shape metrics results.
+
    Parameters
    ----------
    id: str
@@ -526,6 +527,7 @@ def get_number_of_stars(metrics):
    ----------
    metrics: dict
        A dictionary containing the metrics results per run
+
    Returns
    -------
    list_of_stars: list
diff --git a/src/wf_psf/psf_models/models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py
index 5c643a3a..95a11615 100644
--- a/src/wf_psf/psf_models/models/psf_model_parametric.py
+++ b/src/wf_psf/psf_models/models/psf_model_parametric.py
@@ -166,7 +166,6 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N):
        ``simPSF_np = wf_psf.sims.psf_simulator.PSFSimulator(...)``
        ``phase_N = simPSF_np.feasible_N(lambda_obs)``
        """
-
        # Initialise the monochromatic PSF batch calculator
        tf_batch_mono_psf = TFBatchMonochromaticPSF(
            obscurations=self.obscurations,
diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index e2c84868..fb8bc902 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -26,8 +26,6 @@
    TFPhysicalLayer,
 )
 from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor
-from wf_psf.utils.read_config import RecursiveNamespace
-from wf_psf.utils.configs_handler import DataConfigHandler

 import logging
@@ -282,7 +280,6 @@ def tf_zernike_OPD(self):

    def _build_tf_batch_poly_PSF(self):
        """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations."""
-
        return TFBatchPolychromaticPSF(
            obscurations=self.obscurations,
            output_Q=self.output_Q,
diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py
index e445f7af..e41e3536 100644
--- a/src/wf_psf/psf_models/psf_model_loader.py
+++ b/src/wf_psf/psf_models/psf_model_loader.py
@@ -10,7 +10,6 @@
 import logging
 from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath
-import tensorflow as tf

 logger = logging.getLogger(__name__)
diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py
index 98d450f2..30de4804 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_layers.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py
@@ -189,7 +189,6 @@ def calculate_monochromatic_PSF(self, packed_elems):

    def calculate_polychromatic_PSF(self, packed_elems):
        """Calculate a polychromatic PSF."""
-
        self.current_opd = packed_elems[0][tf.newaxis, :, :]
        SED_pack_data = packed_elems[1]
@@ -214,7 +213,6 @@ def _calculate_polychromatic_PSF(elems_to_unpack):

    def call(self, inputs):
        """Calculate the batch polychromatic PSFs."""
-
        # Unpack Inputs
        opd_batch = inputs[0]
        packed_SED_data = inputs[1]
@@ -299,7 +297,6 @@ def set_output_params(self, output_Q, output_dim):

    def call(self, opd_batch):
        """Calculate the batch poly PSFs."""
-
        if self.phase_N is None:
            self.set_lambda_phaseN()
@@ -312,7 +309,7 @@ def _calculate_PSF_batch(elems_to_unpack):
            swap_memory=True,
        )

-        mono_psf_batch = _calculate_PSF_batch((opd_batch))
+        mono_psf_batch = _calculate_PSF_batch(opd_batch)

        return mono_psf_batch

@@ -320,7 +317,6 @@ def _calculate_PSF_batch(elems_to_unpack):
 class TFNonParametricPolynomialVariationsOPD(tf.keras.layers.Layer):
    """Non-parametric OPD generation with polynomial variations.

-
    Parameters
    ----------
    x_lims: [int, int]
@@ -424,7 +420,6 @@ def call(self, positions):
 class TFNonParametricMCCDOPDv2(tf.keras.layers.Layer):
    """Non-parametric OPD generation with hybrid-MCCD variations.
-
    Parameters
    ----------
    obs_pos: tensor(n_stars, 2)
@@ -642,7 +637,6 @@ def calc_index(idx_pos):
 class TFNonParametricGraphOPD(tf.keras.layers.Layer):
    """Non-parametric OPD generation with only graph-constraint variations.

-
    Parameters
    ----------
    obs_pos: tensor(n_stars, 2)
@@ -750,7 +744,6 @@ def set_alpha_identity(self):

    def predict(self, positions):
        """Prediction step."""
-
        ## Graph part
        A_graph_train = tf.linalg.matmul(self.graph_dic, self.alpha_graph)
        # RBF interpolation
@@ -960,7 +953,6 @@ def call(self, positions):
        If the shape of the input `positions` tensor is not compatible.

        """
-
        # Find indices for all positions in one batch operation
        idx = find_position_indices(self.obs_pos, positions)
diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py
index c94ca2e5..86575298 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py
@@ -222,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat):
        self.output_Q = model_params.output_Q

        # Inputs: TF_physical_layer
-        self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32)
+        self.obs_pos = ensure_tensor(
+            get_data_array(data, data.run_type, key="positions"), dtype=tf.float32
+        )
        self.zks_prior = get_ground_truth_zernike(data)
        self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy()
diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py
index 09540e60..4bd1246a 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_utils.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py
@@ -16,7 +16,6 @@
 """
 import tensorflow as tf
-import numpy as np

 @tf.function
diff --git a/src/wf_psf/run.py b/src/wf_psf/run.py
index 828a6e17..171519bb 100644
--- a/src/wf_psf/run.py
+++ b/src/wf_psf/run.py
@@ -93,9 +93,7 @@ def mainMethod():

    except Exception as e:
        logger.error(
-            "Check your config file {} for errors. Error Msg: {}.".format(
-                args.conffile, e
-            ),
+            f"Check your config file {args.conffile} for errors. Error Msg: {e}.",
            exc_info=True,
        )
diff --git a/src/wf_psf/sims/psf_simulator.py b/src/wf_psf/sims/psf_simulator.py
index 9b7b58af..dec06ff3 100644
--- a/src/wf_psf/sims/psf_simulator.py
+++ b/src/wf_psf/sims/psf_simulator.py
@@ -19,7 +19,7 @@
    print("Problem importing skimage..")

-class PSFSimulator(object):
+class PSFSimulator:
    """Simulate PSFs.

    In the future the zernike maps could be created with galsim or some other
@@ -198,7 +198,7 @@ def fft_diffract(wf, output_Q, output_dim=64):
        start = int(psf.shape[0] // 2 - (output_dim * output_Q) // 2)
        stop = int(psf.shape[0] // 2 + (output_dim * output_Q) // 2)
    else:
-        start = int(0)
+        start = 0
        stop = psf.shape[0]

    # Crop psf
@@ -366,7 +366,6 @@ def decimate_im(input_im, decim_f):

    Based on the PIL library using the default interpolator.
""" - pil_im = PIL.Image.fromarray(input_im) (width, height) = (pil_im.width // decim_f, pil_im.height // decim_f) im_resized = pil_im.resize((width, height)) @@ -593,7 +592,6 @@ def calculate_wfe_rms(self, z_coeffs=None): def check_wfe_rms(self, z_coeffs=None, max_wfe_rms=None): """Check if Zernike coefficients are within the maximum admitted error.""" - if max_wfe_rms is None: max_wfe_rms = self.max_wfe_rms diff --git a/src/wf_psf/sims/spatial_varying_psf.py b/src/wf_psf/sims/spatial_varying_psf.py index 33548d7f..f343b025 100644 --- a/src/wf_psf/sims/spatial_varying_psf.py +++ b/src/wf_psf/sims/spatial_varying_psf.py @@ -240,7 +240,6 @@ def check_position_coordinate_limits(xv, yv, x_lims, y_lims, verbose): None """ - x_check = np.sum(xv >= x_lims[1] * 1.1) + np.sum(xv <= x_lims[0] * 1.1) y_check = np.sum(yv >= y_lims[1] * 1.1) + np.sum(yv <= y_lims[0] * 1.1) @@ -480,7 +479,7 @@ def calculate_zernike( ) -class SpatialVaryingPSF(object): +class SpatialVaryingPSF: """Spatial Varying PSF. Generate PSF field with polynomial variations of Zernike coefficients. @@ -621,7 +620,6 @@ def calculate_wfe_rms(self, xv, yv, polynomial_coeffs): numpy.ndarray An array containing the WFE RMS values for the provided positions. """ - Z = ZernikeHelper.generate_zernike_polynomials( xv, yv, self.x_lims, self.y_lims, self.d_max, polynomial_coeffs ) @@ -645,7 +643,6 @@ def build_polynomial_coeffs(self): ------- None """ - # Build mesh xv_grid, yv_grid = MeshHelper.build_mesh( self.x_lims, self.y_lims, self.grid_points diff --git a/src/wf_psf/tests/test_metrics/conftest.py b/src/wf_psf/tests/test_metrics/conftest.py index f1fe85de..b32dfbf1 100644 --- a/src/wf_psf/tests/test_metrics/conftest.py +++ b/src/wf_psf/tests/test_metrics/conftest.py @@ -10,7 +10,6 @@ import pytest from unittest.mock import patch, MagicMock -import numpy as np import tensorflow as tf diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index c9084503..b994196e 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,6 +1,6 @@ from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler +from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.data.data_handler import DataHandler diff --git a/src/wf_psf/tests/test_training/train_utils_test.py b/src/wf_psf/tests/test_training/train_utils_test.py index 8cefa5f6..703531c8 100644 --- a/src/wf_psf/tests/test_training/train_utils_test.py +++ b/src/wf_psf/tests/test_training/train_utils_test.py @@ -150,7 +150,6 @@ def test_calculate_sample_weights_integration( use_sample_weights, loss, expected_output_type ): """Test different cases for sample weight computation.""" - # Generate dummy image data batch_size, height, width = 5, 32, 32 @@ -416,7 +415,6 @@ def test_general_train_cycle_with_callbacks( mock_test_setup, cycle_def, param_callback, non_param_callback, general_callback ): """Test general_train_cycle with different cycle_def and callback configurations.""" - # Unpack test setup mock_model = mock_test_setup["mock_model"] diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 8898df91..071181c3 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -13,7 +13,6 @@ from wf_psf.utils.io import FileIOHandler from 
wf_psf.utils.configs_handler import ( TrainingConfigHandler, - MetricsConfigHandler, DataConfigHandler, ) import os diff --git a/src/wf_psf/tests/test_utils/conftest.py b/src/wf_psf/tests/test_utils/conftest.py index abcfa93f..caee7878 100644 --- a/src/wf_psf/tests/test_utils/conftest.py +++ b/src/wf_psf/tests/test_utils/conftest.py @@ -10,7 +10,6 @@ import pytest import os -from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler cwd = os.getcwd() diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index 38161edc..f3c11d6e 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -299,7 +299,6 @@ def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): Tensor containing the outputs for validation """ - if training_handler.training_hparams.loss == "mask_mse": loss = train_utils.MaskedMeanSquaredError() monitor = "loss" diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index c8c4ea3a..6971a6e9 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -209,7 +209,6 @@ def run(self): input configuration. """ - train.train( self.training_conf.training, self.data_conf, @@ -581,7 +580,6 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params): metrics_run_id_name: list List containing the model name and id for each training run """ - try: training_conf = read_conf( os.path.join( @@ -596,9 +594,7 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params): except (TypeError, FileNotFoundError): logger.info("Trained model path not provided...") logger.info( - "Trying to retrieve training config file from workdir: {}".format( - wf_outdir - ) + f"Trying to retrieve training config file from workdir: {wf_outdir}" ) training_confs = [ @@ -649,9 +645,7 @@ def load_metrics_into_dict(self): "metrics-" + run_id_name + ".npy", ) logger.info( - "Attempting to read in trained model config file...{}".format( - output_path - ) + f"Attempting to read in trained model config file...{output_path}" ) try: metrics_dict[k].append( diff --git a/src/wf_psf/utils/graph_utils.py b/src/wf_psf/utils/graph_utils.py index 1f6c5463..96f9b270 100644 --- a/src/wf_psf/utils/graph_utils.py +++ b/src/wf_psf/utils/graph_utils.py @@ -1,7 +1,7 @@ import numpy as np -class GraphBuilder(object): +class GraphBuilder: r"""GraphBuilder class. This class computes the necessary quantities for RCA's graph constraint. @@ -112,10 +112,10 @@ def _build_graphs(self): R -= vect.T.dot(vect.dot(R)) if self.verbose: print( - " > selected e: {}\tselected a:".format(e) - + "{}\t chosen index: {}/{}".format(a, j, self.n_eigenvects) + f" > selected e: {e}\tselected a:" + + f"{a}\t chosen index: {j}/{self.n_eigenvects}" ) - self.VT = np.vstack((eigenvect for eigenvect in list_eigenvects)) + self.VT = np.vstack(eigenvect for eigenvect in list_eigenvects) self.alpha = np.zeros((self.n_comp, self.VT.shape[0])) for i in range(self.n_comp): self.alpha[i, i * self.n_eigenvects + idx[i]] = 1 diff --git a/src/wf_psf/utils/io.py b/src/wf_psf/utils/io.py index b49c9a47..b7444c7b 100644 --- a/src/wf_psf/utils/io.py +++ b/src/wf_psf/utils/io.py @@ -118,7 +118,6 @@ def get_timestamp(self): timestamp: str A string representation of the date and time. 
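
        Examples
        --------
        Illustrative only; the exact value depends on the current time
        (``file_handler`` is an assumed ``FileIOHandler`` instance):

        >>> file_handler.get_timestamp()  # doctest: +SKIP
        '202505131133'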
""" - timestamp = datetime.now().strftime("%Y%m%d%H%M") return timestamp @@ -190,7 +189,6 @@ def copy_conffile_to_output_dir(self, source_file): source_file: str Name of source file """ - source = os.path.join(self.config_path, source_file) destination = os.path.join( self.get_config_dir(self._run_output_dir), source_file diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index 922c34b6..7282d4c4 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -56,7 +56,6 @@ def map_entry(entry): entry: type Original type of entry if type is not a dictionary """ - if isinstance(entry, dict): return RecursiveNamespace(**entry) @@ -100,21 +99,19 @@ def read_conf(conf_file): Recursive Namespace object """ - logger.info("Loading...{}".format(conf_file)) - with open(conf_file, "r") as f: + logger.info(f"Loading...{conf_file}") + with open(conf_file) as f: try: my_conf = yaml.safe_load(f) except (ParserError, ScannerError, TypeError): logger.exception( - "There is a syntax problem with your config file. Please check {}.".format( - conf_file - ) + f"There is a syntax problem with your config file. Please check {conf_file}." ) exit() if my_conf is None: raise TypeError( - "Config file {} is empty...Stopping Program.".format(conf_file) + f"Config file {conf_file} is empty...Stopping Program." ) exit() @@ -122,7 +119,7 @@ def read_conf(conf_file): return RecursiveNamespace(**my_conf) except TypeError as e: logger.exception( - "Check your config file for errors. Error Msg: {}.".format(e) + f"Check your config file for errors. Error Msg: {e}." ) exit() @@ -143,7 +140,7 @@ def read_stream(conf_file): A dictionary containing all config files. """ - stream = open(conf_file, "r") + stream = open(conf_file) docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 7c8d5e41..eeb67357 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Optional, Tuple +from typing import Tuple import tensorflow as tf import tensorflow_addons as tfa import PIL @@ -114,7 +114,6 @@ def generate_SED_elems(SED, sim_psf_toolkit, n_bins=20): n_bins: int Number of wavelength bins """ - feasible_wv, SED_norm = sim_psf_toolkit.calc_SED_wave_values(SED, n_bins) feasible_N = np.array([sim_psf_toolkit.feasible_N(_wv) for _wv in feasible_wv]) @@ -140,7 +139,6 @@ def generate_SED_elems_in_tensorflow( tf_dtype: tf. Tensor Flow data type """ - feasible_wv, SED_norm = sim_psf_toolkit.calc_SED_wave_values(SED, n_bins) feasible_N = np.array([sim_psf_toolkit.feasible_N(_wv) for _wv in feasible_wv]) @@ -405,7 +403,7 @@ def estimate_noise(self, image: np.ndarray, mask: np.ndarray = None) -> float: return self.sigma_mad(image[self.window]) -class ZernikeInterpolation(object): +class ZernikeInterpolation: """Interpolate zernikes This class helps to interpolate zernikes using only the closest K elements @@ -478,7 +476,7 @@ def interpolate_zks(self, interp_positions): return tf.squeeze(interp_zks, axis=1) -class IndependentZernikeInterpolation(object): +class IndependentZernikeInterpolation: """Interpolate each Zernike polynomial independently The interpolation is done independently for each Zernike polynomial. 
From d11f0bb5ac76e5286bbd3d8e61ce760b6a39fc2e Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 17:41:46 +0200 Subject: [PATCH 111/146] refactor the noise std dev calculation --- src/wf_psf/training/train_utils.py | 68 ++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/src/wf_psf/training/train_utils.py b/src/wf_psf/training/train_utils.py index a3d6f95a..5b54be12 100644 --- a/src/wf_psf/training/train_utils.py +++ b/src/wf_psf/training/train_utils.py @@ -382,6 +382,55 @@ def configure_optimizer_and_loss( return optimizer, loss, metrics +def compute_noise_std_from_stars( + outputs: np.ndarray, loss: Union[str, Callable, None] +) -> Optional[np.ndarray]: + """ + Compute the noise standard deviation from star images. + + Parameters + ---------- + outputs: np.ndarray + A 3D array of shape (batch_size, height, width) representing star images. + The first dimension is the batch size, and the next two dimensions are the image height and width. + loss: str, callable, optional + The loss function used for training. If the loss name is "masked_mean_squared_error", the function will calculate the noise standard deviation for masked images. + + Returns + ------- + np.ndarray + An array of standard deviations for each image in the batch, or None if no images are provided. + + """ + if outputs is not None and len(outputs.shape) >= 3: + img_dim = (outputs.shape[1], outputs.shape[2]) + win_rad = np.ceil(outputs.shape[1] / 3.33) + std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad) + + if loss is not None and ( + (isinstance(loss, str) and loss == "masked_mean_squared_error") + or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error") + ): + logger.info("Estimating noise standard deviation for masked images..") + images = outputs[..., 0] + masks = np.array(1 - outputs[..., 1], dtype=bool) + imgs_std = np.array( + [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)] + ) + else: + logger.info("Estimating noise standard deviation for images..") + # Estimate noise standard deviation + imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs]) + + else: + logger.warning( + "No images provided for noise standard deviation estimation or there was a problem with the input images." + ) + imgs_std = None + + return imgs_std + + def calculate_sample_weights( outputs: np.ndarray, use_sample_weights: bool, @@ -420,24 +469,9 @@ def calculate_sample_weights( An array of sample weights, or None if `use_sample_weights` is False. 
""" if use_sample_weights: - img_dim = (outputs.shape[1], outputs.shape[2]) - win_rad = np.ceil(outputs.shape[1] / 3.33) - std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad) - if loss is not None and ( - (isinstance(loss, str) and loss == "masked_mean_squared_error") - or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error") - ): - logger.info("Estimating noise standard deviation for masked images..") - images = outputs[..., 0] - masks = np.array(1 - outputs[..., 1], dtype=bool) - imgs_std = np.array( - [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)] - ) - else: - logger.info("Estimating noise standard deviation for images..") - # Estimate noise standard deviation - imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs]) + # Compute noise standard deviation from images + imgs_std = compute_noise_std_from_stars(outputs, loss) # Calculate variances variances = imgs_std**2 From 45e204412cf49df2e74e29e170c53c35f4cc7d79 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 17:54:16 +0200 Subject: [PATCH 112/146] improve refactoring --- src/wf_psf/training/train_utils.py | 43 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/wf_psf/training/train_utils.py b/src/wf_psf/training/train_utils.py index 5b54be12..e27aeb19 100644 --- a/src/wf_psf/training/train_utils.py +++ b/src/wf_psf/training/train_utils.py @@ -383,18 +383,19 @@ def configure_optimizer_and_loss( def compute_noise_std_from_stars( - outputs: np.ndarray, loss: Union[str, Callable, None] + images: np.ndarray, + masks: Optional[np.ndarray] = None, ) -> Optional[np.ndarray]: """ Compute the noise standard deviation from star images. Parameters ---------- - outputs: np.ndarray + images: np.ndarray A 3D array of shape (batch_size, height, width) representing star images. The first dimension is the batch size, and the next two dimensions are the image height and width. - loss: str, callable, optional - The loss function used for training. If the loss name is "masked_mean_squared_error", the function will calculate the noise standard deviation for masked images. + masks: np.ndarray, optional + A 3D array of shape (batch_size, height, width) representing masks for the images. Returns ------- @@ -402,25 +403,17 @@ def compute_noise_std_from_stars( An array of standard deviations for each image in the batch, or None if no images are provided. 
""" - if outputs is not None and len(outputs.shape) >= 3: - img_dim = (outputs.shape[1], outputs.shape[2]) - win_rad = np.ceil(outputs.shape[1] / 3.33) + if images is not None and len(images.shape) >= 3: + img_dim = (images.shape[1], images.shape[2]) + win_rad = np.ceil(images.shape[1] / 3.33) std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad) - if loss is not None and ( - (isinstance(loss, str) and loss == "masked_mean_squared_error") - or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error") - ): - logger.info("Estimating noise standard deviation for masked images..") - images = outputs[..., 0] - masks = np.array(1 - outputs[..., 1], dtype=bool) + if masks is not None: imgs_std = np.array( [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)] ) else: - logger.info("Estimating noise standard deviation for images..") - # Estimate noise standard deviation - imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs]) + imgs_std = np.array([std_est.estimate_noise(_im) for _im in images]) else: logger.warning( @@ -449,7 +442,7 @@ def calculate_sample_weights( ---------- outputs: np.ndarray A 3D array of shape (batch_size, height, width) representing images, where the first dimension is the batch size - and the next two dimensions are the image height and width. + and the next two dimensions are the image height and width. It can contain the masks in an extra dimension, e.g., (batch_size, height, width, 2), use_sample_weights: bool Flag indicating whether to compute sample weights. If True, sample weights will be computed based on the image noise. loss: str, callable, optional @@ -471,7 +464,19 @@ def calculate_sample_weights( if use_sample_weights: # Compute noise standard deviation from images - imgs_std = compute_noise_std_from_stars(outputs, loss) + if loss is not None and ( + (isinstance(loss, str) and loss == "masked_mean_squared_error") + or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error") + ): + logger.info("Estimating noise standard deviation for masked images..") + images = outputs[..., 0] + masks = np.array(1 - outputs[..., 1], dtype=bool) + imgs_std = compute_noise_std_from_stars(images, masks) + + else: + logger.info("Estimating noise standard deviation for images..") + # Estimate noise standard deviation + imgs_std = compute_noise_std_from_stars(outputs) # Calculate variances variances = imgs_std**2 From 1ac42411189a85a00db7a8ff2dd605660bb42e3c Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 19:24:29 +0200 Subject: [PATCH 113/146] add chi2 metric function --- src/wf_psf/metrics/metrics.py | 155 ++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 4324d6e6..64141f2a 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -14,6 +14,7 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.psf_models import build_PSF_model from wf_psf.sims import psf_simulator as psf_simulator +from wf_psf.training.train_utils import compute_noise_std_from_stars import logging logger = logging.getLogger(__name__) @@ -157,6 +158,160 @@ def compute_poly_metric( return rmse, rel_rmse, std_rmse, std_rel_rmse +def compute_chi2_metric( + tf_trained_psf_model, + gt_tf_psf_model, + simPSF_np, + tf_pos, + tf_SEDs, + n_bins_lda=20, + n_bins_gt=20, + batch_size=16, + dataset_dict=None, + mask=False, +): + """Calculate the chi2 metric for polychromatic reconstructions at observation resolution. 
+
+    The ``tf_trained_psf_model`` should be the model to evaluate, and the
+    ``gt_tf_psf_model`` should be loaded with the ground truth PSF field.
+
+    Parameters
+    ----------
+    tf_trained_psf_model: PSF field object
+        Trained model to evaluate.
+    gt_tf_psf_model: PSF field object
+        Ground truth model to produce gt observations at any position
+        and wavelength.
+    simPSF_np: PSF simulator object
+        Simulation object to be used by ``generate_packed_elems`` function.
+    tf_pos: Tensor or numpy.ndarray [batch x 2] floats
+        Positions to evaluate the model.
+    tf_SEDs: numpy.ndarray [batch x SED_samples x 2]
+        SED samples for the corresponding positions.
+    n_bins_lda: int
+        Number of wavelength bins to use for the polychromatic PSF.
+    n_bins_gt: int
+        Number of wavelength bins to use for the ground truth polychromatic PSF.
+    batch_size: int
+        Batch size for the PSF calculations.
+    dataset_dict: dict
+        Dictionary containing the dataset information. If provided, and if the `'stars'` key
+        is present, the noiseless stars from the dataset are used to compute the metrics.
+        Otherwise, the stars are generated from the gt model.
+        Default is `None`.
+    mask: bool
+        If `True`, predictions are masked using the same mask as the target data, ensuring
+        that metric calculations consider only unmasked regions.
+        Default is `False`.
+
+    Returns
+    -------
+    reduced_chi2_stat: float
+        Reduced chi squared value.
+    avg_noise_std_dev: float
+        Average noise standard deviation used for the chi squared calculation.
+
+    """
+    # Create flag
+    noiseless_stars = False
+
+    # Generate SED data list for the model
+    packed_SED_data = [
+        utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_lda)
+        for _sed in tf_SEDs
+    ]
+    tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)
+    tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])
+    pred_inputs = [tf_pos, tf_packed_SED_data]
+
+    # Model prediction
+    preds = tf_trained_psf_model.predict(x=pred_inputs, batch_size=batch_size)
+
+    # Ground truth data preparation
+    if dataset_dict is None or "stars" not in dataset_dict:
+        logger.info(
+            "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
+        )
+        # The stars will be noiseless as we are recreating them from the ground truth model
+        noiseless_stars = True
+
+        # Change interpolation parameters for the ground truth simPSF
+        simPSF_np.SED_interp_pts_per_bin = 0
+        simPSF_np.SED_sigma = 0
+        # Generate SED data list for gt model
+        packed_SED_data = [
+            utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_gt)
+            for _sed in tf_SEDs
+        ]
+        tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)
+        tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])
+        pred_inputs = [tf_pos, tf_packed_SED_data]
+
+        # Ground Truth model prediction
+        reference_stars = gt_tf_psf_model.predict(x=pred_inputs, batch_size=batch_size)
+
+    else:
+        logger.info("Using precomputed ground truth stars from dataset_dict['stars'].")
+        reference_stars = dataset_dict["stars"]
+
+    # If the data is masked, mask the predictions
+    if mask:
+        logger.info(
+            "Applying masks to predictions. Only unmasked regions will be considered for metric calculations."
+ ) + # Change convention + masks = 1 - dataset_dict["masks"] + # Ensure masks as float dtype + masks = masks.astype(preds.dtype) + + else: + # We create a dummy mask of ones + masks = np.ones_like(reference_stars, dtype=preds.dtype) + + # Compute noise standard deviation from the reference stars + if not noiseless_stars: + estimated_std_dev = compute_noise_std_from_stars(reference_stars, masks) + # Check if there is a zero value + if np.any(estimated_std_dev == 0): + logger.info( + "Chi2 metric calculation: Some estimated standard deviations are zero. Setting them to 1 to avoid division by zero." + ) + estimated_std_dev[estimated_std_dev == 0] = 1.0 + else: + # If the stars are noiseless, we set the std dev to 1 + estimated_std_dev = np.ones(reference_stars.shape[0], dtype=preds.dtype) + logger.info( + "Using noiseless stars for chi2 calculation. Setting all std dev to 1." + ) + + # Compute residuals + residuals = (reference_stars - preds) * masks + + # Standardize residuals + standardized_residuals = np.array( + [ + (residual - np.sum(residual) / np.sum(mask)) / std_est + for residual, mask, std_est in zip(residuals, masks, estimated_std_dev) + ] + ) + # Compute the degrees of freedom and the mean + degrees_of_freedom = np.sum(masks) + mean_standardized_residuals = np.sum(standardized_residuals) / degrees_of_freedom + # The degrees of freedom is reduced by 1 because we're removing the mean (see Cochran's theorem) + reduced_chi2_stat = np.sum( + ((standardized_residuals - mean_standardized_residuals) * masks) ** 2 + ) / (degrees_of_freedom - 1) + + # Average std deviation + mean_std_dev = np.mean(estimated_std_dev) + + # Print chi2 values + logger.info("Reduced chi2:\t %.5e" % (reduced_chi2_stat)) + logger.info("Average noise std dev:\t %.5e" % (mean_std_dev)) + + return reduced_chi2_stat, mean_std_dev + + def compute_mono_metric( tf_semiparam_field, gt_tf_semiparam_field, From 43213b0e2ae9a50e830203548bbad0599588b470 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 19:25:09 +0200 Subject: [PATCH 114/146] added chi2 metric to list and automatic black formating --- src/wf_psf/metrics/metrics_interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3fa9d0bd..cf1f595a 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -379,6 +379,7 @@ def evaluate_model( # Define the metric evaluation functions metric_functions = { "poly_metric": metrics_handler.evaluate_metrics_polychromatic_lowres, + "chi2_metric": metrics_handler.evaluate_metrics_chi2, "mono_metric": metrics_handler.evaluate_metrics_mono_rmse, "opd_metric": metrics_handler.evaluate_metrics_opd, "shape_results_dict": metrics_handler.evaluate_metrics_shape, From c54928e2f96a592f1df88fc2c5271754f7fd36d9 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 19:27:28 +0200 Subject: [PATCH 115/146] improve naming --- src/wf_psf/metrics/metrics.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 64141f2a..cc860310 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -270,16 +270,16 @@ def compute_chi2_metric( # Compute noise standard deviation from the reference stars if not noiseless_stars: - estimated_std_dev = compute_noise_std_from_stars(reference_stars, masks) + estimated_noise_std_dev = compute_noise_std_from_stars(reference_stars, 
masks)
        # Check if there is a zero value
        if np.any(estimated_noise_std_dev == 0):
            logger.info(
                "Chi2 metric calculation: Some estimated standard deviations are zero. Setting them to 1 to avoid division by zero."
            )
            estimated_noise_std_dev[estimated_noise_std_dev == 0] = 1.0
    else:
        # If the stars are noiseless, we set the std dev to 1
-        estimated_std_dev = np.ones(reference_stars.shape[0], dtype=preds.dtype)
+        estimated_noise_std_dev = np.ones(reference_stars.shape[0], dtype=preds.dtype)
        logger.info(
            "Using noiseless stars for chi2 calculation. Setting all std dev to 1."
        )
@@ -291,7 +291,9 @@
    standardized_residuals = np.array(
        [
            (residual - np.sum(residual) / np.sum(mask)) / std_est
-            for residual, mask, std_est in zip(residuals, masks, estimated_std_dev)
+            for residual, mask, std_est in zip(
+                residuals, masks, estimated_noise_std_dev
+            )
        ]
    )
    # Compute the degrees of freedom and the mean
@@ -303,13 +305,13 @@
    ) / (degrees_of_freedom - 1)

    # Average std deviation
-    mean_std_dev = np.mean(estimated_std_dev)
+    mean_noise_std_dev = np.mean(estimated_noise_std_dev)

    # Print chi2 values
    logger.info("Reduced chi2:\t %.5e" % (reduced_chi2_stat))
-    logger.info("Average noise std dev:\t %.5e" % (mean_std_dev))
+    logger.info("Average noise std dev:\t %.5e" % (mean_noise_std_dev))

-    return reduced_chi2_stat, mean_std_dev
+    return reduced_chi2_stat, mean_noise_std_dev

From 56b5e7238a76c369275cec4233790d0e4d73069b Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Fri, 27 Jun 2025 19:29:14 +0200
Subject: [PATCH 116/146] improve docstring

---
 src/wf_psf/metrics/metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index cc860310..5a15a899 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -209,7 +209,7 @@ def compute_chi2_metric(
    reduced_chi2_stat: float
        Reduced chi squared value.
    avg_noise_std_dev: float
-        Average noise standard deviation used for the chi squared calculation.
+        Average estimated noise standard deviation used for the chi squared calculation.

    """

From 9fd138d013d32f6db72fa092f414f3d57c920029 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Fri, 27 Jun 2025 19:29:26 +0200
Subject: [PATCH 117/146] add chi2 evaluation function

---
 src/wf_psf/metrics/metrics_interface.py | 68 +++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py
index cf1f595a..b6b15acb 100644
--- a/src/wf_psf/metrics/metrics_interface.py
+++ b/src/wf_psf/metrics/metrics_interface.py
@@ -110,6 +110,74 @@ def evaluate_metrics_polychromatic_lowres(
+    def evaluate_metrics_chi2(
+        self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any]
+    ) -> Dict[str, float]:
+        """Evaluate reduced chi2 metric for low-resolution polychromatic PSF.
+
+        This method computes the reduced chi2 metric for a
+        low-resolution polychromatic Point Spread Function (PSF) model.
+
+        Parameters
+        ----------
+        psf_model : object
+            An instance of the PSF model selected for metrics evaluation.
+        simPSF : object
+            An instance of the PSFSimulator.
+        data : object
+            A DataConfigHandler object containing training and test datasets.
+        dataset : dict
+            Dictionary containing dataset details, including:
+            - ``SEDs`` (Spectral Energy Distributions)
+            - ``positions`` (star positions)
+            - ``C_poly`` : Tensor or None, optional
+                The Zernike coefficient matrix used to generate simulations of the PSF model. It
+                defines the Zernike polynomials, up to a given order, used to simulate the PSF
+                field. It may be present in some datasets and is only required by some classes;
+                if it is neither present nor required, the model proceeds without it.
+
+
+        Returns
+        -------
+        dict
+            A dictionary containing the reduced chi2 statistic and the average estimated
+            noise standard deviation used for the chi squared calculation.
+
+            - ``reduced_chi2`` : float
+                Reduced chi squared value.
+            - ``mean_noise_std_dev`` : float
+                Average estimated noise standard deviation used for the chi squared calculation.
+
+        """
+        logger.info("Computing the reduced chi2 metric at low resolution.")
+
+        # Check if testing predictions should be masked
+        mask = self.trained_model.training_hparams.loss == "mask_mse"
+
+        # Compute metrics
+        reduced_chi2_stat, mean_noise_std_dev = wf_metrics.compute_chi2_metric(
+            tf_semiparam_field=psf_model,
+            gt_tf_semiparam_field=psf_models.get_psf_model(
+                self.metrics_params.ground_truth_model.model_params,
+                self.metrics_params.metrics_hparams,
+                data,
+                dataset.get("C_poly", None),  # Extract C_poly or default to None
+            ),
+            simPSF_np=simPSF,
+            tf_pos=dataset["positions"],
+            tf_SEDs=dataset["SEDs"],
+            n_bins_lda=self.trained_model.model_params.n_bins_lda,
+            n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda,
+            batch_size=self.metrics_params.metrics_hparams.batch_size,
+            dataset_dict=dataset,
+            mask=mask,
+        )
+
+        return {
+            "reduced_chi2": reduced_chi2_stat,
+            "mean_noise_std_dev": mean_noise_std_dev,
+        }
+
     def evaluate_metrics_mono_rmse(
         self, psf_model: Any, simPSF: Any, data: Any, dataset: Dict[str, Any]
     ) -> Dict[str, float]:

From 6303290705773e959a3c2da990d86400b4ed2cc9 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Fri, 27 Jun 2025 19:35:49 +0200
Subject: [PATCH 118/146] add evaluation flag

---
 src/wf_psf/metrics/metrics_interface.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py
index b6b15acb..4f20eca9 100644
--- a/src/wf_psf/metrics/metrics_interface.py
+++ b/src/wf_psf/metrics/metrics_interface.py
@@ -430,6 +430,10 @@ def evaluate_model(
             "test": True,
             "train": True,
         },
+        "chi2_metric": {
+            "test": metrics_params.eval_chi2_metric,
+            "train": metrics_params.eval_chi2_metric,
+        },
         "mono_metric": {
             "test": metrics_params.eval_mono_metric,
             "train": metrics_params.eval_mono_metric,

From 22bf695adfd4db32ac3c5402d9f9f03c74917539 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Fri, 27 Jun 2025 19:43:53 +0200
Subject: [PATCH 119/146] fix parameter name bug

---
 src/wf_psf/metrics/metrics_interface.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py
index 4f20eca9..20a6d4f4 100644
--- a/src/wf_psf/metrics/metrics_interface.py
+++ b/src/wf_psf/metrics/metrics_interface.py
@@ -156,8 +156,8 @@ def evaluate_metrics_chi2(
 
         # Compute metrics
         reduced_chi2_stat, mean_noise_std_dev = wf_metrics.compute_chi2_metric(
-            tf_semiparam_field=psf_model,
-            gt_tf_semiparam_field=psf_models.get_psf_model(
+            tf_trained_psf_model=psf_model,
+            gt_tf_psf_model=psf_models.get_psf_model(
                 self.metrics_params.ground_truth_model.model_params,
self.metrics_params.metrics_hparams, data, From 4a0580d403d3f884bf625f62fe1490a928394b7f Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 27 Jun 2025 19:49:55 +0200 Subject: [PATCH 120/146] fix data type problem --- src/wf_psf/metrics/metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 5a15a899..002662c0 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -270,7 +270,9 @@ def compute_chi2_metric( # Compute noise standard deviation from the reference stars if not noiseless_stars: - estimated_noise_std_dev = compute_noise_std_from_stars(reference_stars, masks) + estimated_noise_std_dev = compute_noise_std_from_stars( + reference_stars, masks.astype(bool) + ) # Check if there is a zero value if np.any(estimated_noise_std_dev == 0): logger.info( From 5dbdecb9cb36e99e5e26190e4fe48b5a24d5dcdd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 121/146] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/data_handler.py | 3 + src/wf_psf/data/data_zernike_utils.py | 1 - src/wf_psf/data/old_zernike_prior.py | 335 ------------------ src/wf_psf/instrument/ccd_misalignments.py | 40 +++ .../tests/test_data/data_handler_test.py | 1 - 5 files changed, 43 insertions(+), 337 deletions(-) delete mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 052fe730..f62af516 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -18,6 +18,9 @@ from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf from typing import Optional, Union +import numpy as np +import tensorflow as tf +from fractions import Fraction import logging logger = logging.getLogger(__name__) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 0fad9c8e..54582e7d 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -359,7 +359,6 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - Notes ----- - This function processes all images at once using vectorized operations. diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py deleted file mode 100644 index 0feb3e70..00000000 --- a/src/wf_psf/data/old_zernike_prior.py +++ /dev/null @@ -1,335 +0,0 @@ -import numpy as np -import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. 
- """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. 
- - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. - """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. 
- batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index da21f630..e6cbab74 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -53,6 +53,46 @@ def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: return zernike_ccd_misalignment_array +def compute_ccd_misalignment(model_params, data): + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + class CCDMisalignmentCalculator: """CCD Misalignment Calculator. 
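Both the retired `old_zernike_prior.py` helpers and the relocated `compute_ccd_misalignment` rely on the same zero-padding convention: every Zernike contribution is stored as an `(n_stars, n_zernike)` array, a single coefficient is left-padded into its proper column, and contributions of different orders are right-padded to a common order before being summed. A minimal sketch of that convention (the coefficient values below are hypothetical, not taken from any dataset):

```python
import numpy as np

# Hypothetical Z4 (defocus) values for three stars, shape (3, 1).
zk4_values = np.array([[0.01], [0.02], [0.03]])

# Left-pad three zero columns so Z4 lands in column index 3 of an
# (n_stars, n_zernike=4) array, as in compute_ccd_misalignment.
zk4_contribution = np.pad(zk4_values, pad_width=[(0, 0), (3, 0)])

# A second, lower-order contribution (e.g. tip/tilt), shape (3, 2).
tip_tilt = np.array([[0.1, -0.1], [0.0, 0.2], [-0.3, 0.1]])

# Right-pad each contribution to the maximum order before summing,
# mirroring the final loop of get_zernike_prior.
max_order = max(c.shape[1] for c in (zk4_contribution, tip_tilt))
total = sum(
    np.pad(c, pad_width=[(0, 0), (0, max_order - c.shape[1])])
    for c in (zk4_contribution, tip_tilt)
)
print(total.shape)  # (3, 4)
```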
diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index d29771a1..4a424d00 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -198,7 +198,6 @@ def test_extract_star_data_partially_missing_key(mock_data): def test_extract_star_data_tensor_conversion(mock_data): """Test that the function properly converts TensorFlow tensors to NumPy arrays.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - assert isinstance(result, np.ndarray), "The result should be a NumPy array" assert result.dtype == np.float32, "The NumPy array should have dtype float32" From 545b96c212780942e0a9ba9a5d63a90bf2aa735b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 122/146] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 86 ++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index f62af516..e1bbbe89 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -21,6 +21,7 @@ import numpy as np import tensorflow as tf from fractions import Fraction +from typing import Optional, Union import logging logger = logging.getLogger(__name__) @@ -80,47 +81,52 @@ def __init__( sed_data: Optional[Union[dict, list]] = None, ): """ - Initialize the DataHandler for PSF dataset preparation. - - This constructor sets up the dataset handler used for PSF simulation tasks, - such as training, testing, or inference. It supports three modes of use: - - 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing - must be triggered manually via `load_dataset()` and `process_sed_data()`. - 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, - and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. - 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded - from disk using `data_params`, and SEDs are extracted and processed automatically. - - Parameters - ---------- - dataset_type : str - One of {"train", "test", "inference"} indicating dataset usage. - data_params : RecursiveNamespace - Configuration object with paths, preprocessing options, and metadata. - simPSF : PSFSimulator - Used to convert SEDs to TensorFlow format. - n_bins_lambda : int - Number of wavelength bins for the SEDs. - load_data : bool, optional - Whether to automatically load and process the dataset (default: True). - dataset : dict or list, optional - A pre-loaded dataset to use directly (overrides `load_data`). - sed_data : array-like, optional - Pre-loaded SED data to use directly. If not provided but `dataset` is, - SEDs are taken from `dataset["SEDs"]`. - - Raises - ------ - ValueError - If SEDs cannot be found in either `dataset` or as `sed_data`. - - Notes - ----- - - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor - `load_data=True` is used. - - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. 
**Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. + + Parameters + ---------- + dataset_type : str + One of {"train", "test", "inference"} indicating dataset usage. + data_params : RecursiveNamespace + Configuration object with paths, preprocessing options, and metadata. + simPSF : PSFSimulator + Used to convert SEDs to TensorFlow format. + n_bins_lambda : int + Number of wavelength bins for the SEDs. + load_data : bool, optional + Whether to automatically load and process the dataset (default: True). + dataset : dict or list, optional + A pre-loaded dataset to use directly (overrides `load_data`). + sed_data : array-like, optional + <<<<<<< HEAD + Pre-loaded SED data to use directly. If not provided but `dataset` is, + ======= + Pre-loaded SED data to use directly. If not provided but `dataset` is, + >>>>>>> 4b896e3 (Refactor data_handler with new utility functions to validate and process datasets and update docstrings) + SEDs are taken from `dataset["SEDs"]`. + + Raises + ------ + ValueError + If SEDs cannot be found in either `dataset` or as `sed_data`. + + Notes + ----- + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + `load_data=True` is used. + - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ + self.dataset_type = dataset_type self.data_params = data_params self.simPSF = simPSF From 09da1666d1cfbd73a249828c0678d3264dfcd597 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 123/146] Update unit tests associated to changes in data_handler.py --- src/wf_psf/tests/test_data/data_handler_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 4a424d00..85200b8a 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -9,6 +9,11 @@ from wf_psf.utils.read_config import RecursiveNamespace +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + def mock_sed(): # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like return np.linspace(0.1, 1.0, 50) From a1f215cccc473fcae91e21f5318f57a323b2697e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:52:26 +0200 Subject: [PATCH 124/146] Refactor TFPhysicalPolychromaticField to lazy load property objects and attributes dynamically at run-time according to the run_type: training or inference --- .../models/psf_model_physical_polychromatic.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index fb8bc902..263f0430 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -142,6 +142,23 @@ def 
__init__(self, model_params, training_params, data, coeff_mat=None): zks_total_contribution_np.shape[1], ) + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + @property + def tf_batch_poly_PSF(self): + """Lazily initialize the batch polychromatic PSF layer.""" + if not hasattr(self, "_tf_batch_poly_PSF"): + obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) + # Precompute zernike maps as tf.float32 self._zernike_maps = psfm.generate_zernike_maps_3d( n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter From 3aa325ce0e80c02b84e9ef3cbbec5f1f908a3bef Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:03:50 +0200 Subject: [PATCH 125/146] Use ensure_tensor method from tf_utils.py to check/convert to tensorflow type; Remove get_obs_positions and replace with ensure_tensor method; add property tf_positions --- src/wf_psf/data/data_handler.py | 87 +++++++++++++++------------------ 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index e1bbbe89..ae690c69 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,9 +17,6 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf -from typing import Optional, Union -import numpy as np -import tensorflow as tf from fractions import Fraction from typing import Optional, Union import logging @@ -81,50 +78,46 @@ def __init__( sed_data: Optional[Union[dict, list]] = None, ): """ - Initialize the DataHandler for PSF dataset preparation. - - This constructor sets up the dataset handler used for PSF simulation tasks, - such as training, testing, or inference. It supports three modes of use: - - 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing - must be triggered manually via `load_dataset()` and `process_sed_data()`. - 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, - and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. - 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded - from disk using `data_params`, and SEDs are extracted and processed automatically. - - Parameters - ---------- - dataset_type : str - One of {"train", "test", "inference"} indicating dataset usage. - data_params : RecursiveNamespace - Configuration object with paths, preprocessing options, and metadata. - simPSF : PSFSimulator - Used to convert SEDs to TensorFlow format. - n_bins_lambda : int - Number of wavelength bins for the SEDs. - load_data : bool, optional - Whether to automatically load and process the dataset (default: True). - dataset : dict or list, optional - A pre-loaded dataset to use directly (overrides `load_data`). - sed_data : array-like, optional - <<<<<<< HEAD - Pre-loaded SED data to use directly. If not provided but `dataset` is, - ======= - Pre-loaded SED data to use directly. 
If not provided but `dataset` is, - >>>>>>> 4b896e3 (Refactor data_handler with new utility functions to validate and process datasets and update docstrings) - SEDs are taken from `dataset["SEDs"]`. - - Raises - ------ - ValueError - If SEDs cannot be found in either `dataset` or as `sed_data`. - - Notes - ----- - - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor - `load_data=True` is used. - - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. + Initialize the DataHandler for PSF dataset preparation. + + This constructor sets up the dataset handler used for PSF simulation tasks, + such as training, testing, or inference. It supports three modes of use: + + 1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing + must be triggered manually via `load_dataset()` and `process_sed_data()`. + 2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly, + and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`. + 3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded + from disk using `data_params`, and SEDs are extracted and processed automatically. + + Parameters + ---------- + dataset_type : str + One of {"train", "test", "inference"} indicating dataset usage. + data_params : RecursiveNamespace + Configuration object with paths, preprocessing options, and metadata. + simPSF : PSFSimulator + Used to convert SEDs to TensorFlow format. + n_bins_lambda : int + Number of wavelength bins for the SEDs. + load_data : bool, optional + Whether to automatically load and process the dataset (default: True). + dataset : dict or list, optional + A pre-loaded dataset to use directly (overrides `load_data`). + sed_data : array-like, optional + Pre-loaded SED data to use directly. If not provided but `dataset` is, + SEDs are taken from `dataset["SEDs"]`. + + Raises + ------ + ValueError + If SEDs cannot be found in either `dataset` or as `sed_data`. + + Notes + ----- + - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor + `load_data=True` is used. + - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ self.dataset_type = dataset_type From 8383deb9de56927112cd6555f48edea7e9c788bf Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 7 Jul 2025 11:39:08 +0200 Subject: [PATCH 126/146] add median noise calculation and todo comment --- src/wf_psf/metrics/metrics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 002662c0..f5f39afd 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -308,10 +308,13 @@ def compute_chi2_metric( # Average std deviation mean_noise_std_dev = np.mean(estimated_noise_std_dev) + median_noise_std_dev = np.median(estimated_noise_std_dev) + # TODO: Compute the reduced chi2 for each star. Show median and mean values. 
     # Print chi2 values
-    logger.info("Reduced chi2:\t %.5e" % (reduced_chi2_stat))
+    logger.info("Reduced chi2:\t\t %.5e" % (reduced_chi2_stat))
     logger.info("Average noise std dev:\t %.5e" % (mean_noise_std_dev))
+    logger.info("Median noise std dev:\t %.5e" % (median_noise_std_dev))
 
     return reduced_chi2_stat, mean_noise_std_dev

From eaacdb545b29bfcb5470e7f0177fb987a458c70e Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 7 Jul 2025 11:42:08 +0200
Subject: [PATCH 127/146] update docstring

---
 src/wf_psf/metrics/metrics_interface.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py
index 20a6d4f4..ef72f9a7 100644
--- a/src/wf_psf/metrics/metrics_interface.py
+++ b/src/wf_psf/metrics/metrics_interface.py
@@ -396,8 +396,6 @@ def evaluate_model(
         DataHandler object containing training and test data
     psf_model: object
         PSF model object
-    weights_path: str
-        Directory location of model weights
     metrics_output: str
         Directory location of metrics output
 

From f17add0bbac00b2d27a8fd24b0251cde43590a7a Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 7 Jul 2025 14:47:55 +0200
Subject: [PATCH 128/146] update tabs

---
 src/wf_psf/metrics/metrics.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index f5f39afd..f992d187 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -312,9 +312,9 @@ def compute_chi2_metric(
     # TODO: Compute the reduced chi2 for each star. Show median and mean values.
 
     # Print chi2 values
-    logger.info("Reduced chi2:\t\t %.5e" % (reduced_chi2_stat))
+    logger.info("Reduced chi2:\t\t\t %.5e" % (reduced_chi2_stat))
     logger.info("Average noise std dev:\t %.5e" % (mean_noise_std_dev))
-    logger.info("Median noise std dev:\t %.5e" % (median_noise_std_dev))
+    logger.info("Median noise std dev:\t\t %.5e" % (median_noise_std_dev))
 
     return reduced_chi2_stat, mean_noise_std_dev

From ba4cfa4b68db2bce26d61bf330a9d430781fbf0c Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 7 Jul 2025 14:38:33 +0200
Subject: [PATCH 129/146] add info about the reference simulated datasets

---
 data/generation/README.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/data/generation/README.md b/data/generation/README.md
index aa90bbbf..2a363cbb 100644
--- a/data/generation/README.md
+++ b/data/generation/README.md
@@ -8,6 +8,16 @@ To run the script:
 ```
 python data_generation_script.py -c data_generation_params_v.0.1.0.yml
 ```
 
+> ⚠️ Warning
+>
+> Even though we use it as a reference, there are some differences with the original dataset used for the WaveDiff paper (Liaudat et al. 2023):
+> - The assignment of SEDs to each star will not match that of the original dataset, although the same templates are used.
+> - The assigned noise level (SNR) for each star will not match the original dataset, although the same distribution is used.
+>
+> Nevertheless, the `C_poly` matrices and the `positions` will match. Therefore, results from the new datasets will not be exactly the same as those from the original datasets.
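+>
+> A minimal sketch of how one could at least make the new draws repeatable between runs, assuming the script draws from NumPy's global RNG (as its `np.random.rand` calls suggest; the seed value here is hypothetical):
+>
+> ```python
+> import numpy as np
+>
+> # Fix the global RNG state once, before any SED or SNR assignment is drawn.
+> np.random.seed(20250707)
+> ```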
+ + +### Dataset description **Dataset 0.x.x:** @@ -56,3 +66,5 @@ python data_generation_script.py -c data_generation_params_v.0.1.0.yml - v3.2.1/2/3/4 with dummy (unitary) masks - v3.3.1/2/3/4 with realistic masks + + From 15c22fed6fd79fc0e7b058bce7e3cac604138a30 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 7 Jul 2025 15:14:27 +0200 Subject: [PATCH 130/146] add noisy stars and masked noisy stars for the test dataset --- data/generation/data_generation_script.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/data/generation/data_generation_script.py b/data/generation/data_generation_script.py index f9746ad9..dd8f93fb 100644 --- a/data/generation/data_generation_script.py +++ b/data/generation/data_generation_script.py @@ -599,6 +599,7 @@ def main(args): # Add the centroid shifts to the Zernike coefficients train_zks += train_delta_centroid_shifts + # TO_DEFINE: For now we do add the centroid shifts to the test dataset. Should we? test_zks += test_delta_centroid_shifts # ------------ # @@ -776,6 +777,21 @@ def main(args): axis=0, ) + # Also add noise to the test stars + noisy_test_poly_psf_np = np.copy(test_poly_psf_np) + # Generate a dataset with a SNR varying randomly within the desired range + rand_SNR = ( + np.random.rand(noisy_test_poly_psf_np.shape[0]) * (SNR_range[1] - SNR_range[0]) + ) + SNR_range[0] + # Add Gaussian noise to the observations + noisy_test_poly_psf_np = np.stack( + [ + add_noise(_im, desired_SNR=_SNR) + for _im, _SNR in zip(noisy_test_poly_psf_np, rand_SNR) + ], + axis=0, + ) + # ------------ # # Generate masks @@ -797,6 +813,13 @@ def main(args): noisy_train_poly_psf_np.dtype ) + masked_noisy_test_poly_psf_np = np.copy(noisy_test_poly_psf_np) + # Apply the random masks to the test stars + masked_noisy_test_poly_psf_np = ( + masked_noisy_test_poly_psf_np + * test_masks.astype(noisy_test_poly_psf_np.dtype) + ) + # Turn masks to SHE convention. 1 (True) means to mask and 0 (False) means to keep train_masks = ~train_masks test_masks = ~test_masks @@ -1131,6 +1154,7 @@ def main(args): test_psf_dataset = { "stars": test_poly_psf_np, "SR_stars": SR_test_poly_psf_np, + "noisy_stars": noisy_test_poly_psf_np, "positions": test_positions, "SEDs": test_SED_np, "zernike_GT": test_zks, @@ -1138,6 +1162,7 @@ def main(args): if add_masks: test_psf_dataset["masks"] = test_masks + test_psf_dataset["masked_noisy_stars"] = masked_noisy_test_poly_psf_np if add_ccd_misalignments: test_psf_dataset["zernike_ccd_misalignments"] = test_delta_Z3_arr From 9c2da9683c81e0f8343745284dbcf12c09f3393d Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 7 Jul 2025 15:29:07 +0200 Subject: [PATCH 131/146] add noshift generation for the test dataset --- data/generation/data_generation_script.py | 32 ++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/data/generation/data_generation_script.py b/data/generation/data_generation_script.py index dd8f93fb..2321fb41 100644 --- a/data/generation/data_generation_script.py +++ b/data/generation/data_generation_script.py @@ -561,6 +561,7 @@ def main(args): # ------------ # # Centroid shifts + no_shift_test_zks = None if add_intrapixel_shifts: @@ -599,7 +600,7 @@ def main(args): # Add the centroid shifts to the Zernike coefficients train_zks += train_delta_centroid_shifts - # TO_DEFINE: For now we do add the centroid shifts to the test dataset. Should we? 
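+        # Keep an unshifted copy of the test Zernikes so that a shift-free
+        # version of the test stars and SR test stars can be generated
+        # alongside the shifted ones.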
+ no_shift_test_zks = np.copy(test_zks) test_zks += test_delta_centroid_shifts # ------------ # @@ -753,12 +754,24 @@ def main(args): sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins) ) + # Generate test polychromatic PSFs without shifts + if no_shift_test_zks is not None: + test_poly_psf_noshift_list = [] + print("Generate test PSFs at observation resolution without shifts") + for it in tqdm(range(no_shift_test_zks.shape[0])): + sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it, :]) + test_poly_psf_noshift_list.append( + sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins) + ) + # Generate numpy arrays from the lists train_poly_psf_np = np.array(train_poly_psf_list) train_SED_np = np.array(train_SED_list) test_poly_psf_np = np.array(test_poly_psf_list) test_SED_np = np.array(test_SED_list) + if no_shift_test_zks is not None: + test_poly_psf_noshift_np = np.array(test_poly_psf_noshift_list) # Generate the noisy train stars # Copy the training stars @@ -1116,8 +1129,21 @@ def main(args): SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins) ) + # Generate the test super resolved (SR) polychromatic PSFs without shifts + if no_shift_test_zks is not None: + SR_test_poly_psf_noshift_list = [] + + print("Generate testing SR PSFs no shifts") + for it_j in tqdm(range(n_test_stars)): + SR_sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it_j, :]) + SR_test_poly_psf_noshift_list.append( + SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins) + ) + # Generate numpy arrays from the lists SR_test_poly_psf_np = np.array(SR_test_poly_psf_list) + if no_shift_test_zks is not None: + SR_test_poly_psf_noshift_np = np.array(SR_test_poly_psf_noshift_list) # ------------ # # Save test datasets @@ -1167,6 +1193,10 @@ def main(args): if add_ccd_misalignments: test_psf_dataset["zernike_ccd_misalignments"] = test_delta_Z3_arr + if no_shift_test_zks is not None: + test_psf_dataset["stars_noshift"] = test_poly_psf_noshift_np + test_psf_dataset["SR_stars_noshift"] = SR_test_poly_psf_noshift_np + if add_intrapixel_shifts: test_psf_dataset["zernike_centroid_shifts"] = test_delta_centroid_shifts test_psf_dataset["pix_centroid_shifts"] = np.stack( From 8fe11d79efa22e8ca2ae029bcb27167514bf7c59 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 7 Jul 2025 15:31:53 +0200 Subject: [PATCH 132/146] update to new paths --- data/generation/data_generation_script.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/generation/data_generation_script.py b/data/generation/data_generation_script.py index 2321fb41..a265ded2 100644 --- a/data/generation/data_generation_script.py +++ b/data/generation/data_generation_script.py @@ -25,9 +25,9 @@ ) from wf_psf.utils.read_config import read_conf, RecursiveNamespace from wf_psf.sims.psf_simulator import PSFSimulator -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff from wf_psf.sims.spatial_varying_psf import SpatialVaryingPSF, ZernikeHelper -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator # Pre-defined colormap From 624b101bc9170c1d6e81c540579ba26775697420 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 7 Jul 2025 16:32:46 +0200 Subject: [PATCH 133/146] update the intrapixel shift range --- data/generation/data_generation_params_v.0.1.0.yml | 2 +- data/generation/data_generation_params_v.0.2.0.yml | 2 
+- data/generation/data_generation_params_v.0.3.0.yml | 2 +- data/generation/data_generation_params_v.1.1.0.yml | 2 +- data/generation/data_generation_params_v.1.2.0.yml | 2 +- data/generation/data_generation_params_v.1.3.0.yml | 2 +- data/generation/data_generation_params_v.2.1.0.yml | 2 +- data/generation/data_generation_params_v.2.2.0.yml | 2 +- data/generation/data_generation_params_v.2.3.0.yml | 2 +- data/generation/data_generation_params_v.3.1.0.yml | 2 +- data/generation/data_generation_params_v.3.2.0.yml | 2 +- data/generation/data_generation_params_v.3.3.0.yml | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/data/generation/data_generation_params_v.0.1.0.yml b/data/generation/data_generation_params_v.0.1.0.yml index e373cd06..221a81c2 100644 --- a/data/generation/data_generation_params_v.0.1.0.yml +++ b/data/generation/data_generation_params_v.0.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.0.2.0.yml b/data/generation/data_generation_params_v.0.2.0.yml index 804d9178..2708a649 100644 --- a/data/generation/data_generation_params_v.0.2.0.yml +++ b/data/generation/data_generation_params_v.0.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.0.3.0.yml b/data/generation/data_generation_params_v.0.3.0.yml index 5c351dc9..d7487e2b 100644 --- a/data/generation/data_generation_params_v.0.3.0.yml +++ b/data/generation/data_generation_params_v.0.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: False # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: False diff --git a/data/generation/data_generation_params_v.1.1.0.yml b/data/generation/data_generation_params_v.1.1.0.yml index 14baccbf..f19d7719 100644 --- a/data/generation/data_generation_params_v.1.1.0.yml +++ b/data/generation/data_generation_params_v.1.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.1.2.0.yml b/data/generation/data_generation_params_v.1.2.0.yml index 6e489eaa..e0a9953c 100644 --- a/data/generation/data_generation_params_v.1.2.0.yml +++ b/data/generation/data_generation_params_v.1.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.1.3.0.yml b/data/generation/data_generation_params_v.1.3.0.yml index 8b48221d..542a1140 100644 --- a/data/generation/data_generation_params_v.1.3.0.yml +++ b/data/generation/data_generation_params_v.1.3.0.yml @@ -106,7 +106,7 @@ 
dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.1.0.yml b/data/generation/data_generation_params_v.2.1.0.yml index a06903ca..f09542b1 100644 --- a/data/generation/data_generation_params_v.2.1.0.yml +++ b/data/generation/data_generation_params_v.2.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.2.0.yml b/data/generation/data_generation_params_v.2.2.0.yml index 7e2354b6..6c37148e 100644 --- a/data/generation/data_generation_params_v.2.2.0.yml +++ b/data/generation/data_generation_params_v.2.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.2.3.0.yml b/data/generation/data_generation_params_v.2.3.0.yml index 27f644ea..63c76356 100644 --- a/data/generation/data_generation_params_v.2.3.0.yml +++ b/data/generation/data_generation_params_v.2.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.1.0.yml b/data/generation/data_generation_params_v.3.1.0.yml index fd3b1c35..e1fc2f7b 100644 --- a/data/generation/data_generation_params_v.3.1.0.yml +++ b/data/generation/data_generation_params_v.3.1.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.2.0.yml b/data/generation/data_generation_params_v.3.2.0.yml index 224ecdeb..2c3c35dd 100644 --- a/data/generation/data_generation_params_v.3.2.0.yml +++ b/data/generation/data_generation_params_v.3.2.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True diff --git a/data/generation/data_generation_params_v.3.3.0.yml b/data/generation/data_generation_params_v.3.3.0.yml index bc457818..9e13d377 100644 --- a/data/generation/data_generation_params_v.3.3.0.yml +++ b/data/generation/data_generation_params_v.3.3.0.yml @@ -106,7 +106,7 @@ dataset_features: # Centroid shift options add_intrapixel_shifts: True # In pixels (should be abs(limits)<0.5) - intrapixel_shift_range: [-0.5, 0.5] + intrapixel_shift_range: [-0.3, 0.3] # CCD misalignment options add_ccd_misalignments: True From 91c4c8420261ca9cb0e9ac5c3e566cdc7d6063f7 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 13:12:35 +0200 Subject: [PATCH 134/146] fix 
small bug for unitary masks --- data/generation/data_generation_script.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data/generation/data_generation_script.py b/data/generation/data_generation_script.py index a265ded2..86ebdd1b 100644 --- a/data/generation/data_generation_script.py +++ b/data/generation/data_generation_script.py @@ -810,6 +810,8 @@ def main(args): if add_masks: + masked_noisy_test_poly_psf_np = np.copy(noisy_test_poly_psf_np) + if mask_type == "random": # Generate random train masks train_masks = generate_n_mask( @@ -826,7 +828,6 @@ def main(args): noisy_train_poly_psf_np.dtype ) - masked_noisy_test_poly_psf_np = np.copy(noisy_test_poly_psf_np) # Apply the random masks to the test stars masked_noisy_test_poly_psf_np = ( masked_noisy_test_poly_psf_np From 91748cd55997074509e78df7b2a088ea13dcb4a0 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 17:12:07 +0200 Subject: [PATCH 135/146] fix chi2 bug --- src/wf_psf/metrics/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index f992d187..18d66d3a 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -292,7 +292,7 @@ def compute_chi2_metric( # Standardize residuals standardized_residuals = np.array( [ - (residual - np.sum(residual) / np.sum(mask)) / std_est + (residual - np.mean(residual)) / (np.sum(mask) * std_est) for residual, mask, std_est in zip( residuals, masks, estimated_noise_std_dev ) From 1e1308efaa8083622daedc9bb15c27c56de0a653 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 17:31:28 +0200 Subject: [PATCH 136/146] add chi2 per image calculation --- src/wf_psf/metrics/metrics.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 18d66d3a..4cbb181e 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -289,15 +289,23 @@ def compute_chi2_metric( # Compute residuals residuals = (reference_stars - preds) * masks - # Standardize residuals + # Standardize residuals -> remove mean and divide by std dev standardized_residuals = np.array( [ - (residual - np.mean(residual)) / (np.sum(mask) * std_est) + (residual - np.sum(residual) / np.sum(mask)) / std_est for residual, mask, std_est in zip( residuals, masks, estimated_noise_std_dev ) ] ) + # Per-image reduced chi2 statistic + reduced_chi2_stat_per_image = np.array( + [ + np.sum((standardized_residual * mask) ** 2) / (np.sum(mask)) + for standardized_residual, mask in zip(standardized_residuals, masks) + ] + ) + # Compute the degrees of freedom and the mean degrees_of_freedom = np.sum(masks) mean_standardized_residuals = np.sum(standardized_residuals) / degrees_of_freedom @@ -310,13 +318,23 @@ def compute_chi2_metric( mean_noise_std_dev = np.mean(estimated_noise_std_dev) median_noise_std_dev = np.median(estimated_noise_std_dev) - # TODO: Compute the reduced chi2 for each star. Show median and mean values. 
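+    # Note: each per-image statistic divides by N_i = np.sum(mask) pixels with
+    # no -1 correction, whereas the global reduced chi2 computed below removes
+    # one degree of freedom for the subtracted mean.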
- # Print chi2 values + # Print chi2 results logger.info("Reduced chi2:\t\t\t %.5e" % (reduced_chi2_stat)) - logger.info("Average noise std dev:\t %.5e" % (mean_noise_std_dev)) + + logger.info( + "Average chi2 per image:\t\t %.5e" % (np.mean(reduced_chi2_stat_per_image)) + ) + logger.info( + "Median chi2 per image:\t\t %.5e" % (np.median(reduced_chi2_stat_per_image)) + ) + logger.info( + "Std dev chi2 per image:\t\t %.5e" % (np.std(reduced_chi2_stat_per_image)) + ) + + logger.info("Average noise std dev:\t\t %.5e" % (mean_noise_std_dev)) logger.info("Median noise std dev:\t\t %.5e" % (median_noise_std_dev)) - return reduced_chi2_stat, mean_noise_std_dev + return reduced_chi2_stat, mean_noise_std_dev, reduced_chi2_stat_per_image def compute_mono_metric( From b01c4d77eb8be13993329b4c82281463455b4ce4 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 17:31:41 +0200 Subject: [PATCH 137/146] add per image results --- src/wf_psf/metrics/metrics_interface.py | 35 ++++++++++++++----------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index ef72f9a7..3965e4e0 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -155,27 +155,30 @@ def evaluate_metrics_chi2( mask = self.trained_model.training_hparams.loss == "mask_mse" # Compute metrics - reduced_chi2_stat, mean_noise_std_dev = wf_metrics.compute_chi2_metric( - tf_trained_psf_model=psf_model, - gt_tf_psf_model=psf_models.get_psf_model( - self.metrics_params.ground_truth_model.model_params, - self.metrics_params.metrics_hparams, - data, - dataset.get("C_poly", None), # Extract C_poly or default to None - ), - simPSF_np=simPSF, - tf_pos=dataset["positions"], - tf_SEDs=dataset["SEDs"], - n_bins_lda=self.trained_model.model_params.n_bins_lda, - n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, - batch_size=self.metrics_params.metrics_hparams.batch_size, - dataset_dict=dataset, - mask=mask, + reduced_chi2_stat, mean_noise_std_dev, reduced_chi2_stat_per_image = ( + wf_metrics.compute_chi2_metric( + tf_trained_psf_model=psf_model, + gt_tf_psf_model=psf_models.get_psf_model( + self.metrics_params.ground_truth_model.model_params, + self.metrics_params.metrics_hparams, + data, + dataset.get("C_poly", None), # Extract C_poly or default to None + ), + simPSF_np=simPSF, + tf_pos=dataset["positions"], + tf_SEDs=dataset["SEDs"], + n_bins_lda=self.trained_model.model_params.n_bins_lda, + n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, + batch_size=self.metrics_params.metrics_hparams.batch_size, + dataset_dict=dataset, + mask=mask, + ) ) return { "reduced_chi2": reduced_chi2_stat, "mean_noise_std_dev": mean_noise_std_dev, + "reduced_chi2_stat_per_image": reduced_chi2_stat_per_image, } def evaluate_metrics_mono_rmse( From 3b94c6ffae96e73ecea223a2a5ef8f5a8070e53b Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 17:36:20 +0200 Subject: [PATCH 138/146] add more output statistics for the chi2 metric --- src/wf_psf/metrics/metrics.py | 27 +++++++++----- src/wf_psf/metrics/metrics_interface.py | 48 +++++++++++++++---------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 4cbb181e..85170f8b 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -314,27 +314,36 @@ def compute_chi2_metric( ((standardized_residuals - 
mean_standardized_residuals) * masks) ** 2 ) / (degrees_of_freedom - 1) - # Average std deviation + # Compute the average and media values of the noise std deviation mean_noise_std_dev = np.mean(estimated_noise_std_dev) median_noise_std_dev = np.median(estimated_noise_std_dev) + # Compute the average, median and std dev of the reduced chi2 statistic per image + mean_reduced_chi2_stat_per_image = np.mean(reduced_chi2_stat_per_image) + median_reduced_chi2_stat_per_image = np.median(reduced_chi2_stat_per_image) + std_reduced_chi2_stat_per_image = np.std(reduced_chi2_stat_per_image) + # Print chi2 results logger.info("Reduced chi2:\t\t\t %.5e" % (reduced_chi2_stat)) + logger.info("Average chi2 per image:\t\t %.5e" % (mean_reduced_chi2_stat_per_image)) logger.info( - "Average chi2 per image:\t\t %.5e" % (np.mean(reduced_chi2_stat_per_image)) - ) - logger.info( - "Median chi2 per image:\t\t %.5e" % (np.median(reduced_chi2_stat_per_image)) - ) - logger.info( - "Std dev chi2 per image:\t\t %.5e" % (np.std(reduced_chi2_stat_per_image)) + "Median chi2 per image:\t\t %.5e" % (median_reduced_chi2_stat_per_image) ) + logger.info("Std dev chi2 per image:\t\t %.5e" % (std_reduced_chi2_stat_per_image)) logger.info("Average noise std dev:\t\t %.5e" % (mean_noise_std_dev)) logger.info("Median noise std dev:\t\t %.5e" % (median_noise_std_dev)) - return reduced_chi2_stat, mean_noise_std_dev, reduced_chi2_stat_per_image + return ( + reduced_chi2_stat, + reduced_chi2_stat_per_image, + mean_reduced_chi2_stat_per_image, + median_reduced_chi2_stat_per_image, + std_reduced_chi2_stat_per_image, + mean_noise_std_dev, + estimated_noise_std_dev, + ) def compute_mono_metric( diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3965e4e0..c06819ef 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -155,30 +155,40 @@ def evaluate_metrics_chi2( mask = self.trained_model.training_hparams.loss == "mask_mse" # Compute metrics - reduced_chi2_stat, mean_noise_std_dev, reduced_chi2_stat_per_image = ( - wf_metrics.compute_chi2_metric( - tf_trained_psf_model=psf_model, - gt_tf_psf_model=psf_models.get_psf_model( - self.metrics_params.ground_truth_model.model_params, - self.metrics_params.metrics_hparams, - data, - dataset.get("C_poly", None), # Extract C_poly or default to None - ), - simPSF_np=simPSF, - tf_pos=dataset["positions"], - tf_SEDs=dataset["SEDs"], - n_bins_lda=self.trained_model.model_params.n_bins_lda, - n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, - batch_size=self.metrics_params.metrics_hparams.batch_size, - dataset_dict=dataset, - mask=mask, - ) + ( + reduced_chi2_stat, + reduced_chi2_stat_per_image, + mean_reduced_chi2_stat_per_image, + median_reduced_chi2_stat_per_image, + std_reduced_chi2_stat_per_image, + mean_noise_std_dev, + estimated_noise_std_dev, + ) = wf_metrics.compute_chi2_metric( + tf_trained_psf_model=psf_model, + gt_tf_psf_model=psf_models.get_psf_model( + self.metrics_params.ground_truth_model.model_params, + self.metrics_params.metrics_hparams, + data, + dataset.get("C_poly", None), # Extract C_poly or default to None + ), + simPSF_np=simPSF, + tf_pos=dataset["positions"], + tf_SEDs=dataset["SEDs"], + n_bins_lda=self.trained_model.model_params.n_bins_lda, + n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, + batch_size=self.metrics_params.metrics_hparams.batch_size, + dataset_dict=dataset, + mask=mask, ) return { "reduced_chi2": reduced_chi2_stat, - 
"mean_noise_std_dev": mean_noise_std_dev, "reduced_chi2_stat_per_image": reduced_chi2_stat_per_image, + "mean_reduced_chi2_stat_per_image": mean_reduced_chi2_stat_per_image, + "median_reduced_chi2_stat_per_image": median_reduced_chi2_stat_per_image, + "std_reduced_chi2_stat_per_image": std_reduced_chi2_stat_per_image, + "mean_noise_std_dev": mean_noise_std_dev, + "estimated_noise_std_dev": estimated_noise_std_dev, } def evaluate_metrics_mono_rmse( From aa52b50ea6a5562507ba2a612e204ade386fc402 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Tue, 8 Jul 2025 17:39:52 +0200 Subject: [PATCH 139/146] update default metric config to include chi2 option --- config/metrics_config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/metrics_config.yaml b/config/metrics_config.yaml index cfaca9b9..e51b5ead 100644 --- a/config/metrics_config.yaml +++ b/config/metrics_config.yaml @@ -12,6 +12,9 @@ metrics: # Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory trained_model_config: + # Evaluate the chi2 metric. + eval_chi2_metric: True + # Evaluate the monchromatic RMSE metric. eval_mono_metric: True From d5d08acd6f235cf0c8f3e3a1bf132f98b9e6699d Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Fri, 19 Sep 2025 14:27:08 +0200 Subject: [PATCH 140/146] merge feature/159-psf-output-from-trained-model with real data metrics --- src/wf_psf/data/data_handler.py | 1 - src/wf_psf/inference/psf_inference.py | 2 -- src/wf_psf/psf_models/tf_modules/tf_utils.py | 3 +++ 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index ae690c69..c18b9bf7 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,7 +17,6 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf -from fractions import Fraction from typing import Optional, Union import logging diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index c7d73249..e660154b 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -386,14 +386,12 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: batch_pos = positions[counter:end_sample, :] batch_seds = sed_data[counter:end_sample, :, :] batch_inputs = [batch_pos, batch_seds] - # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter counter = end_sample - return self._inferred_psfs def get_psfs(self) -> np.ndarray: diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index 4bd1246a..49f3471f 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -83,7 +83,10 @@ def ensure_tensor(input_array, dtype=tf.float32): The input to convert. dtype : tf.DType, optional The desired TensorFlow dtype (default: tf.float32). +<<<<<<< HEAD +======= +>>>>>>> f2d8aa4 (merge feature/159-psf-output-from-trained-model with real data metrics) Returns ------- tf.Tensor From 59e72ac964b8aa23159a45f5ba4093b84cb12f3e Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Wed, 24 Sep 2025 11:27:59 +0200 Subject: [PATCH 141/146] compute the metrics on `noisy_stars` if `stars` is not in the data. 
From 59e72ac964b8aa23159a45f5ba4093b84cb12f3e Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Wed, 24 Sep 2025 11:27:59 +0200
Subject: [PATCH 141/146] compute the metrics on `noisy_stars` if `stars` is
 not in the data; regenerate the stars if neither is present

---
 src/wf_psf/metrics/metrics.py | 36 +++++++++++++++++++++++++++++------
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index 85170f8b..f6876c8b 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -93,7 +93,11 @@ def compute_poly_metric(
     preds = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)
 
     # Ground truth data preparation
-    if dataset_dict is None or "stars" not in dataset_dict:
+    if (
+        dataset_dict is None
+        or "stars" not in dataset_dict
+        or "noisy_stars" not in dataset_dict
+    ):
         logger.info(
             "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
         )
@@ -113,8 +117,16 @@
         gt_preds = gt_tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)
 
     else:
-        logger.info("Using precomputed ground truth stars from dataset_dict['stars'].")
-        gt_preds = dataset_dict["stars"]
+        if "stars" in dataset_dict:
+            gt_preds = dataset_dict["stars"]
+            logger.info(
+                "Using precomputed ground truth stars from dataset_dict['stars']."
+            )
+        elif "noisy_stars" in dataset_dict:
+            gt_preds = dataset_dict["noisy_stars"]
+            logger.info(
+                "Using precomputed noisy ground truth stars from dataset_dict['noisy_stars']."
+            )
 
     # If the data is masked, mask the predictions
     if mask:
@@ -228,7 +240,11 @@ def compute_chi2_metric(
     preds = tf_trained_psf_model.predict(x=pred_inputs, batch_size=batch_size)
 
     # Ground truth data preparation
-    if dataset_dict is None or "stars" not in dataset_dict:
+    if (
+        dataset_dict is None
+        or "stars" not in dataset_dict
+        or "noisy_stars" not in dataset_dict
+    ):
         logger.info(
             "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
         )
@@ -251,8 +267,16 @@
         reference_stars = gt_tf_psf_model.predict(x=pred_inputs, batch_size=batch_size)
 
     else:
-        logger.info("Using precomputed ground truth stars from dataset_dict['stars'].")
-        reference_stars = dataset_dict["stars"]
+        if "stars" in dataset_dict:
+            reference_stars = dataset_dict["stars"]
+            logger.info(
+                "Using precomputed ground truth stars from dataset_dict['stars']."
+            )
+        elif "noisy_stars" in dataset_dict:
+            reference_stars = dataset_dict["noisy_stars"]
+            logger.info(
+                "Using precomputed noisy ground truth stars from dataset_dict['noisy_stars']."
+            )
 
     # If the data is masked, mask the predictions
     if mask:

From 6e119c47f4a076eed3ac07fd859565c3ec856462 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Wed, 24 Sep 2025 12:09:32 +0200
Subject: [PATCH 142/146] fix small bug in nested conditionals

---
 src/wf_psf/metrics/metrics.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index f6876c8b..76895cc7 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -93,10 +93,8 @@ def compute_poly_metric(
     preds = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)
 
     # Ground truth data preparation
-    if (
-        dataset_dict is None
-        or "stars" not in dataset_dict
-        or "noisy_stars" not in dataset_dict
+    if dataset_dict is None or (
+        "stars" not in dataset_dict and "noisy_stars" not in dataset_dict
     ):
         logger.info(
             "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
@@ -240,10 +238,8 @@ def compute_chi2_metric(
     preds = tf_trained_psf_model.predict(x=pred_inputs, batch_size=batch_size)
 
     # Ground truth data preparation
-    if (
-        dataset_dict is None
-        or "stars" not in dataset_dict
-        or "noisy_stars" not in dataset_dict
+    if dataset_dict is None or (
+        "stars" not in dataset_dict and "noisy_stars" not in dataset_dict
    ):
         logger.info(
             "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
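
The bug fixed here is a De Morgan slip: chained `or` over negated memberships regenerates the stars whenever either key is missing, which defeats the `noisy_stars` fallback added in the previous patch. A quick plain-Python illustration with a hypothetical dataset dictionary:

    dataset_dict = {"noisy_stars": [1.0, 2.0]}  # noisy stars present, clean "stars" absent

    # Patch 141 condition: True, so it would needlessly regenerate the stars
    buggy = (
        dataset_dict is None
        or "stars" not in dataset_dict
        or "noisy_stars" not in dataset_dict
    )

    # Patch 142 condition: False, so it correctly falls through to "noisy_stars"
    fixed = dataset_dict is None or (
        "stars" not in dataset_dict and "noisy_stars" not in dataset_dict
    )

    assert buggy is True and fixed is False
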
From 76373fc1c290ca48037abd21ad823546bebaaff5 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 26 Nov 2025 14:50:48 +0100
Subject: [PATCH 143/146] Update evaluate_model unit test with patch for chi2
 metric

---
 src/wf_psf/tests/test_metrics/metrics_interface_test.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py
index b994196e..28653761 100644
--- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py
+++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py
@@ -67,6 +67,10 @@ def test_evaluate_model(
             "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_polychromatic_lowres",
             new_callable=MagicMock,
         ) as mock_evaluate_poly_metric,
+        patch(
+            "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_chi2",
+            new_callable=MagicMock,
+        ) as mock_evaluate_chi2_metric,
         patch(
             "wf_psf.metrics.metrics_interface.MetricsParamsHandler.evaluate_metrics_mono_rmse",
             new_callable=MagicMock,
@@ -98,6 +102,9 @@
         assert (
             mock_evaluate_poly_metric.call_count == 2
         )  # Called twice, once for each dataset
+        assert (
+            mock_evaluate_chi2_metric.call_count == 2
+        )  # Called twice, once for each dataset
         assert (
             mock_evaluate_mono_metric.call_count == 2
         )  # Called twice, once for each dataset

From 81a94fcb8de7a863ea59033fb8211912ce6ee138 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 26 Nov 2025 14:51:34 +0100
Subject: [PATCH 144/146] Update train_util tests with mock data including
 binary mask

---
 .../tests/test_training/train_utils_test.py | 38 +++++++++++++++++--
 src/wf_psf/training/train_utils.py           |  1 -
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/src/wf_psf/tests/test_training/train_utils_test.py b/src/wf_psf/tests/test_training/train_utils_test.py
index 703531c8..ecb4d822 100644
--- a/src/wf_psf/tests/test_training/train_utils_test.py
+++ b/src/wf_psf/tests/test_training/train_utils_test.py
@@ -178,9 +178,21 @@ def test_calculate_sample_weights_integration(
 @pytest.mark.parametrize(
     "loss", [None, "mean_squared_error", "masked_mean_squared_error"]
 )
-def test_calculate_sample_weights_unit(mock_noise_estimator, loss):
+def test_calculate_sample_weights_unit(loss):
     """Test sample weighting strategy with random images."""
-    outputs = np.random.rand(10, 32, 32)  # 10 images of size 32x32
+    # Generate dummy image data
+    batch_size, height, width = 10, 32, 32
+
+    if loss == "masked_mean_squared_error":
+        # Create image-mask pairs: last dimension has [image, mask]
+        outputs = np.random.rand(batch_size, height, width, 2)
+        outputs[..., 1] = np.random.randint(
+            0, 2, size=(batch_size, height, width)
+        )  # Binary mask
+    else:
+        outputs = np.random.rand(batch_size, height, width)
+
+    # Calculate sample weights
     result = train_utils.calculate_sample_weights(
         outputs, use_sample_weights=True, loss=loss
     )
@@ -199,10 +211,28 @@
 @pytest.mark.parametrize(
     "loss", [None, "mean_squared_error", "masked_mean_squared_error"]
 )
-def test_calculate_sample_weights_high_variance(mock_noise_estimator, loss):
+def test_calculate_sample_weights_high_variance(loss):
     """Test case for high variance (noisy images)."""
     # Create high variance images with more noise
-    outputs = np.random.normal(loc=0.0, scale=10.0, size=(5, 32, 32))  # Larger noise
+    # Generate dummy image data
+    batch_size, height, width = 10, 32, 32
+
+    if loss == "masked_mean_squared_error":
+        # Create image-mask pairs: last dimension has [image, mask]
+        outputs = np.zeros((batch_size, height, width, 2), dtype=np.float32)
+        mask_prob: float = 0.5  # Probability of a pixel being unmasked
+        # High variance images
+        outputs[..., 0] = np.random.normal(
+            loc=0.0, scale=10.0, size=(batch_size, height, width)
+        )
+        # Random masks with adjustable sparsity
+        outputs[..., 1] = (
+            np.random.rand(batch_size, height, width) < mask_prob
+        ).astype(np.float32)
+    else:
+        outputs = np.random.normal(
+            loc=0.0, scale=10.0, size=(batch_size, height, width)
+        )  # Larger noise
 
     # Calculate sample weights
     weights = train_utils.calculate_sample_weights(

diff --git a/src/wf_psf/training/train_utils.py b/src/wf_psf/training/train_utils.py
index e27aeb19..b26c112a 100644
--- a/src/wf_psf/training/train_utils.py
+++ b/src/wf_psf/training/train_utils.py
@@ -462,7 +462,6 @@ def calculate_sample_weights(
         An array of sample weights, or None if `use_sample_weights` is False.
     """
     if use_sample_weights:
-
         # Compute noise standard deviation from images
         if loss is not None and (
             (isinstance(loss, str) and loss == "masked_mean_squared_error")
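
The updated tests encode the convention that, for the masked loss, `outputs` stacks the image and a binary mask along the last axis. The hunks do not show how calculate_sample_weights unpacks this, so the following is only a plausible sketch, assuming per-image noise is estimated from unmasked pixels and converted into inverse-variance weights:

    import numpy as np

    def sample_weights_sketch(outputs, loss=None):
        if loss == "masked_mean_squared_error":
            # Unpack image-mask pairs: [..., 0] is the image, [..., 1] the binary mask
            images, masks = outputs[..., 0], outputs[..., 1]
        else:
            images, masks = outputs, np.ones_like(outputs)
        # Per-image noise estimate from unmasked pixels only (assumes non-empty masks)
        noise_std = np.array(
            [img[mask.astype(bool)].std() for img, mask in zip(images, masks)]
        )
        # Inverse-variance weighting: noisier images contribute less to the loss
        weights = 1.0 / np.maximum(noise_std**2, 1e-12)
        return weights * len(weights) / weights.sum()
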
From b9cbd6b4d27df2c8035aae4a4e55cc327f8e5d02 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 26 Nov 2025 15:28:30 +0100
Subject: [PATCH 145/146] Remove duplicated function and unnecessary variable
 re-assignment

---
 src/wf_psf/instrument/ccd_misalignments.py | 44 +---------------------
 1 file changed, 1 insertion(+), 43 deletions(-)

diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py
index e6cbab74..6b73babe 100644
--- a/src/wf_psf/instrument/ccd_misalignments.py
+++ b/src/wf_psf/instrument/ccd_misalignments.py
@@ -28,48 +28,6 @@ def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray:
     zernike_ccd_misalignment_array : np.ndarray
         Numpy array containing the Zernike contributions to model the CCD chip misalignments.
     """
-    obs_positions = positions
-
-    ccd_misalignment_calculator = CCDMisalignmentCalculator(
-        tiles_path=model_params.ccd_misalignments_input_path,
-        x_lims=model_params.x_lims,
-        y_lims=model_params.y_lims,
-        tel_focal_length=model_params.tel_focal_length,
-        tel_diameter=model_params.tel_diameter,
-    )
-    # Compute required zernike 4 for each position
-    zk4_values = np.array(
-        [
-            ccd_misalignment_calculator.get_zk4_from_position(single_pos)
-            for single_pos in obs_positions
-        ]
-    ).reshape(-1, 1)
-
-    # Zero pad array to get shape (n_stars, n_zernike=4)
-    zernike_ccd_misalignment_array = np.pad(
-        zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0
-    )
-
-    return zernike_ccd_misalignment_array
-
-
-def compute_ccd_misalignment(model_params, data):
-    """Compute CCD misalignment.
-
-    Parameters
-    ----------
-    model_params : RecursiveNamespace
-        Object containing parameters for this PSF model class.
-    data : DataConfigHandler
-        Object containing training and test datasets.
-
-    Returns
-    -------
-    zernike_ccd_misalignment_array : np.ndarray
-        Numpy array containing the Zernike contributions to model the CCD chip misalignments.
-    """
-    obs_positions = get_np_obs_positions(data)
-
     ccd_misalignment_calculator = CCDMisalignmentCalculator(
         tiles_path=model_params.ccd_misalignments_input_path,
         x_lims=model_params.x_lims,
@@ -81,7 +39,7 @@ def compute_ccd_misalignment(model_params, data):
     zk4_values = np.array(
         [
             ccd_misalignment_calculator.get_zk4_from_position(single_pos)
-            for single_pos in obs_positions
+            for single_pos in positions
         ]
     ).reshape(-1, 1)
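
The zero-padding step kept by this cleanup places a column of Z4 (defocus) values into the fourth slot of an (n_stars, n_zernike=4) array. A tiny worked example of that np.pad call, with made-up values:

    import numpy as np

    # Hypothetical Z4 contributions for three stars, shape (3, 1)
    zk4_values = np.array([[0.10], [-0.20], [0.05]])

    # Left-pad three zero columns so Z4 lands in the fourth Zernike slot
    zernike_ccd_misalignment_array = np.pad(
        zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0
    )
    # -> [[0.  0.  0.   0.1 ]
    #     [0.  0.  0.  -0.2 ]
    #     [0.  0.  0.   0.05]]
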
From 1e500b9fbb72516f6e9a440cc29551be147ffbfe Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 26 Nov 2025 15:39:22 +0100
Subject: [PATCH 146/146] Remove functions left over from previous rebase

---
 .../models/psf_model_physical_polychromatic.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index 263f0430..fb8bc902 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -142,23 +142,6 @@ def __init__(self, model_params, training_params, data, coeff_mat=None):
             zks_total_contribution_np.shape[1],
         )
 
-    @property
-    def tf_zernike_OPD(self):
-        """Lazy loading of the Zernike Optical Path Difference (OPD) layer."""
-        if not hasattr(self, "_tf_zernike_OPD"):
-            self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps)
-        return self._tf_zernike_OPD
-
-    @property
-    def tf_batch_poly_PSF(self):
-        """Lazily initialize the batch polychromatic PSF layer."""
-        if not hasattr(self, "_tf_batch_poly_PSF"):
-            obscurations = psfm.tf_obscurations(
-                pupil_diam=self.model_params.pupil_diameter,
-                N_filter=self.model_params.LP_filter_length,
-                rotation_angle=self.model_params.obscuration_rotation_angle,
-            )
-
         # Precompute zernike maps as tf.float32
         self._zernike_maps = psfm.generate_zernike_maps_3d(
             n_zernikes=self._n_zks_total,
             pupil_diam=self.model_params.pupil_diameter