diff --git a/README.md b/README.md index 759a8142..747d387f 100644 --- a/README.md +++ b/README.md @@ -13,3 +13,15 @@ This repository includes: - A differentiable PSF model entirely built in [Tensorflow](https://github.com/tensorflow/tensorflow). - A [numpy-based PSF simulator](https://github.com/CosmoStat/wf-psf/tree/dummy_main/src/wf_psf/sims). - All the scripts, jobs and notebooks required to reproduce the results in [arXiv:2203.04908](http://arxiv.org/abs/2203.04908) and [arXiv:2111.12541](https://arxiv.org/abs/2111.12541). + +--------------------------------------------------------------------------- +NOTICE ABOUT THIRD-PARTY CODE + +This repository contains code copied from TensorFlow Addons +(https://github.com/tensorflow/addons), specifically the +`interpolate_spline` and `types.py` modules. + +Those files are licensed under the Apache License, Version 2.0. +The copyright and license headers are preserved in each copied file. + +All other code in this repository remains under the MIT License. diff --git a/THIRD_PARTY_LICENSES/TFA_LICENSE.txt b/THIRD_PARTY_LICENSES/TFA_LICENSE.txt new file mode 100644 index 00000000..93cc9aae --- /dev/null +++ b/THIRD_PARTY_LICENSES/TFA_LICENSE.txt @@ -0,0 +1,203 @@ +Copyright 2018 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/changelog.d/20260123_164629_jennifer.pollack_configurable_optimizer.md b/changelog.d/20260123_164629_jennifer.pollack_configurable_optimizer.md new file mode 100644 index 00000000..3d3243cf --- /dev/null +++ b/changelog.d/20260123_164629_jennifer.pollack_configurable_optimizer.md @@ -0,0 +1,43 @@ + + + + +### New features + +- Added a configurable optimizer selection system via the new `optimizer.py` module and its `get_optimizer` function +- Optimizer configuration now supports multiple input types: `RecursiveNamespace` from configs, dictionaries, or string names +- Added support for hyperparameter overrides (learning rate, `beta_1`/`beta_2`, epsilon, amsgrad) via YAML or programmatic configuration +- RectifiedAdam optimizer now dynamically imports TensorFlow Addons only when explicitly specified in the configuration + + +### Bug fixes + +- Fixed optimizer compatibility across TensorFlow versions by automatically falling back to `tf.keras.optimizers.legacy.Adam` when running on TF < 2.11 + + + + +### Internal changes + +- Refactored `build_PSF_model` to accept either a Keras optimizer instance or a configuration resolved through `get_optimizer` +- Added `interpolation.py` and `types.py` modules with vendored code from the TensorFlow Addons repository +- Replaced `tfa.image.interpolate_spline` with the local `tfa_interpolate_spline_rbf` implementation +- Added comprehensive unit tests in `optimizer_test.py` and `interpolation_test.py` +- Updated the README and added the `THIRD_PARTY_LICENSES` directory with the TensorFlow Addons license +- Training now runs on TensorFlow 2.11 without requiring a TensorFlow Addons installation +- Removed TensorFlow Addons as a required dependency; the RectifiedAdam optimizer now requires an explicit TFA installation if needed diff --git a/config/metrics_config.yaml b/config/metrics_config.yaml index cfaca9b9..f8ddd65f 100644 --- a/config/metrics_config.yaml +++ b/config/metrics_config.yaml @@ -141,6 +141,15 @@ metrics: # Batch size to use for the evaluation. batch_size: 16 + # Optimizer used when re-compiling the PSF models for metrics evaluation. + optimizer: + name: 'adam' # Only standard Adam is used for metrics + learning_rate: 1e-2 + beta_1: 0.9 + beta_2: 0.999 + epsilon: 1e-07 + amsgrad: False + # Save RMS error for each super resolved PSF in the test dataset in addition to the mean across the FOV." # Flag to get Super-Resolution pixel PSF RMSE for each individual test star. # If `True`, the relative pixel RMSE of each star is added to ther saving dictionary. diff --git a/config/training_config.yaml b/config/training_config.yaml index 09a741dd..2201e339 100644 --- a/config/training_config.yaml +++ b/config/training_config.yaml @@ -153,6 +153,10 @@ training: # Default: 'mask_mse'. Choose 'mse' if the dataset does not include a mask. loss: 'mask_mse' + # Optimizer to use during training. Options are: 'adam' or 'rectified_adam'. + optimizer: + name: 'rectified_adam' + multi_cycle_params: # Number of training cycles to perform. Each cycle may use different learning rates or number of epochs.
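For orientation, a minimal usage sketch of the configuration surface added above, assuming only the `get_optimizer` signature introduced in `src/wf_psf/utils/optimizer.py` further below; the dict keys mirror the YAML `optimizer:` blocks, and `rectified_adam` additionally requires the optional `addons` extra:

    from wf_psf.utils.optimizer import get_optimizer

    # Dict form, mirroring the `optimizer:` block in metrics_config.yaml.
    opt = get_optimizer({"name": "adam", "learning_rate": 1e-2, "beta_1": 0.9})

    # Plain string form uses the built-in defaults (learning_rate=1e-3, beta_1=0.9, ...);
    # keyword overrides are applied on top of any configured values.
    opt = get_optimizer("adam", learning_rate=5e-3)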
diff --git a/pyproject.toml b/pyproject.toml index d9cc7182..ec2c1711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ dependencies = [ "numpy>=1.26.4,<2.0", "scipy", "tensorflow==2.11.0", - "tensorflow-addons", "tensorflow-estimator", "zernike", "opencv-python", @@ -41,6 +40,10 @@ docs = [ "scriv", ] +addons = [ + "tensorflow-addons", +] + lint = [ "ruff", ] diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 4cf611d6..0447d596 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -381,8 +381,9 @@ def compute_shape_metrics( output_Q=1, output_dim=64, batch_size=16, - opt_stars_rel_pix_rmse=False, + optimizer_settings=None, dataset_dict=None, + opt_stars_rel_pix_rmse=False, ): """Compute the pixel, shape and size RMSE of a PSF model. @@ -418,15 +419,18 @@ def compute_shape_metrics( Output dimension of the square PSF stamps. batch_size: int Batch size to process the PSF estimations. - opt_stars_rel_pix_rmse: bool - If `True`, the relative pixel RMSE of each star is added to ther saving dictionary. - The summary statistics are always computed. - Default is `False`. + optimizer_settings: RecursiveNamespace, dict, or str, optional + Optimizer configuration (parsed from YAML or built programmatically), or an optimizer name string. Default is `None`. dataset_dict: dict Dictionary containing the dataset information. If provided, and if the `'super_res_stars'` key is present, the noiseless super resolved stars from the dataset are used to compute the metrics. Otherwise, the stars are generated from the gt model. Default is `None`. + opt_stars_rel_pix_rmse: bool + If `True`, the relative pixel RMSE of each star is added to the saving dictionary. + The summary statistics are always computed. + Default is `False`. + Returns ------- @@ -445,8 +449,12 @@ def compute_shape_metrics( gt_tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim) # Need to compile the models again - tf_semiparam_field = build_PSF_model(tf_semiparam_field) - gt_tf_semiparam_field = build_PSF_model(gt_tf_semiparam_field) + tf_semiparam_field = build_PSF_model( + tf_semiparam_field, optimizer=optimizer_settings + ) + gt_tf_semiparam_field = build_PSF_model( + gt_tf_semiparam_field, optimizer=optimizer_settings + ) # Generate SED data list packed_SED_data = [ diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 31cdc664..3dff2c6c 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -296,9 +296,10 @@ def evaluate_metrics_shape( tf_pos=dataset["positions"], n_bins_lda=self.trained_model.model_params.n_bins_lda, n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda, - batch_size=self.metrics_params.metrics_hparams.batch_size, output_Q=self.metrics_params.metrics_hparams.output_Q, output_dim=self.metrics_params.metrics_hparams.output_dim, + batch_size=self.metrics_params.metrics_hparams.batch_size, + optimizer_settings=self.metrics_params.metrics_hparams.optimizer, opt_stars_rel_pix_rmse=self.metrics_params.metrics_hparams.opt_stars_rel_pix_rmse, dataset_dict=dataset, ) diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py index dcb6abd9..463d1c52 100644 --- a/src/wf_psf/psf_models/psf_models.py +++ b/src/wf_psf/psf_models/psf_models.py @@ -11,6 +11,7 @@ import tensorflow as tf from wf_psf.sims.psf_simulator import PSFSimulator from wf_psf.utils.utils import zernike_generator +from wf_psf.utils.optimizer import is_optimizer_instance,
get_optimizer import glob import logging @@ -160,12 +161,12 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None): if loss is None: loss = tf.keras.losses.MeanSquaredError() - # Define optimizer function - if optimizer is None: - optimizer = tf.keras.optimizers.legacy.Adam( - learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False - ) - + # Handle optimizer: accept either a ready Keras optimizer instance or a + # configuration (RecursiveNamespace, dict, or name string) that is + # resolved through get_optimizer. + if not is_optimizer_instance(optimizer): + optimizer = get_optimizer(optimizer_config=optimizer) + # Define metric functions if metrics is None: metrics = [tf.keras.metrics.MeanSquaredError()] diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_layers.py index 6d0681a0..eda43305 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_layers.py @@ -7,10 +7,10 @@ """ import tensorflow as tf -import tensorflow_addons as tfa from wf_psf.psf_models.tf_modules import TFMonochromaticPSF from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils +from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf import logging logger = logging.getLogger(__name__) @@ -577,7 +577,7 @@ def predict(self, positions): # Order 2 means a thin_plate RBF interpolation # All tensors need to expand one dimension to fulfil requirement in # the tfa's interpolate_spline function - A_interp_graph = tfa.image.interpolate_spline( + A_interp_graph = tfa_interpolate_spline_rbf( train_points=tf.expand_dims(self.obs_pos, axis=0), train_values=tf.expand_dims(A_graph_train, axis=0), query_points=tf.expand_dims(positions, axis=0), @@ -758,7 +758,7 @@ def predict(self, positions): # Order 2 means a thin_plate RBF interpolation # All tensors need to expand one dimension to fulfil requirement in # the tfa's interpolate_spline function - A_interp_graph = tfa.image.interpolate_spline( + A_interp_graph = tfa_interpolate_spline_rbf( train_points=tf.expand_dims(self.obs_pos, axis=0), train_values=tf.expand_dims(A_graph_train, axis=0), query_points=tf.expand_dims(positions, axis=0), @@ -895,7 +895,7 @@ def interpolate_all(self, positions): # Order 2 means a thin_plate RBF interpolation # All tensors need to expand one dimension to fulfil requirement in # the tfa's interpolate_spline function - interp_zks = tfa.image.interpolate_spline( + interp_zks = tfa_interpolate_spline_rbf( train_points=tf.expand_dims(self.obs_pos, axis=0), train_values=tf.expand_dims(self.zks_prior, axis=0), query_points=tf.expand_dims(positions, axis=0),
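As a usage note on the `build_PSF_model` change above: the function now accepts either a ready optimizer instance or any configuration form understood by `get_optimizer`. A brief sketch, where `model` is a hypothetical placeholder for a PSF model instance:

    import tensorflow as tf
    from wf_psf.psf_models.psf_models import build_PSF_model

    # A Keras optimizer instance is detected via duck typing and used as-is.
    model = build_PSF_model(model, optimizer=tf.keras.optimizers.Adam(1e-3))

    # Anything else (dict, RecursiveNamespace, name string, or None) is
    # resolved through get_optimizer; None yields the default Adam settings.
    model = build_PSF_model(model, optimizer={"name": "adam", "learning_rate": 1e-2})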
diff --git a/src/wf_psf/tests/test_utils/interpolation_test.py b/src/wf_psf/tests/test_utils/interpolation_test.py new file mode 100644 index 00000000..e27f7c30 --- /dev/null +++ b/src/wf_psf/tests/test_utils/interpolation_test.py @@ -0,0 +1,60 @@ +# src/wf_psf/tests/test_utils/interpolation_test.py + +import pytest +import tensorflow as tf +from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf as interpolate_spline + + +@pytest.fixture +def simple_1d_data(): + train_points = tf.constant([[0.0], [1.0], [2.0]], dtype=tf.float32) + train_values = tf.constant([[0.0], [1.0], [4.0]], dtype=tf.float32) + query_points = tf.constant([[0.5], [1.5]], dtype=tf.float32) + return train_points, train_values, query_points + + +def test_output_shape(simple_1d_data): + train_points, train_values, query_points = simple_1d_data + with tf.device("/CPU:0"): + result = interpolate_spline( + train_points=tf.expand_dims(train_points, axis=0), + train_values=tf.expand_dims(train_values, axis=0), + query_points=tf.expand_dims(query_points, axis=0), + order=2, + regularization_weight=0.0, + ) + # Expect shape: [1, n_query, n_values] + assert result.shape == (1, 2, 1) + + +def test_differentiability(simple_1d_data): + train_points, train_values, query_points = simple_1d_data + query = tf.Variable(query_points, dtype=tf.float32) + with tf.device("/CPU:0"): + with tf.GradientTape() as tape: + result = interpolate_spline( + train_points=tf.expand_dims(train_points, axis=0), + train_values=tf.expand_dims(train_values, axis=0), + query_points=tf.expand_dims(query, axis=0), + order=2, + regularization_weight=0.0, + ) + loss = tf.reduce_sum(result) + grad = tape.gradient(loss, query) + + assert grad is not None + assert grad.shape == query.shape + + +@pytest.mark.parametrize("order", [1, 2, 3]) +def test_order_variants(simple_1d_data, order): + train_points, train_values, query_points = simple_1d_data + with tf.device("/CPU:0"): + result = interpolate_spline( + train_points=tf.expand_dims(train_points, axis=0), + train_values=tf.expand_dims(train_values, axis=0), + query_points=tf.expand_dims(query_points, axis=0), + order=order, + regularization_weight=0.0, + ) + assert result.shape == (1, 2, 1) diff --git a/src/wf_psf/tests/test_utils/optimizer_test.py b/src/wf_psf/tests/test_utils/optimizer_test.py new file mode 100644 index 00000000..104e1696 --- /dev/null +++ b/src/wf_psf/tests/test_utils/optimizer_test.py @@ -0,0 +1,112 @@ +import sys +import pytest +from types import SimpleNamespace +from wf_psf.utils.optimizer import get_optimizer +from wf_psf.utils.read_config import RecursiveNamespace + + +# Dummy optimizer classes +class DummyAdam: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.name = "Adam" + self.beta_1 = kwargs.get("beta_1") + self.beta_2 = kwargs.get("beta_2") + self.epsilon = kwargs.get("epsilon") + self.amsgrad = kwargs.get("amsgrad") + self.learning_rate = kwargs.get("learning_rate") + + +class DummyLegacyAdam(DummyAdam): + pass + + +class DummyRAdam: + def __init__(self, **kwargs): + self.kwargs = kwargs + self._name = "RectifiedAdam" + self.learning_rate = kwargs.get("learning_rate") + + +# Parametric test for Adam with overrides +@pytest.mark.parametrize( + "optimizer_input, expected_lr, expected_beta1", + [ + ("adam", 0.001, 0.9), + ({"name": "adam", "learning_rate": 0.01, "beta_1": 0.8}, 0.01, 0.8), + (RecursiveNamespace(name="adam", learning_rate=0.02, beta_1=0.85), 0.02, 0.85), + ], +) +def test_adam_optimizer_overrides( + monkeypatch, optimizer_input, expected_lr, expected_beta1 +): + # Mock TF >= 2.11 + fake_tf = SimpleNamespace( + __version__="2.11.0", + keras=SimpleNamespace( + optimizers=SimpleNamespace( + Adam=DummyAdam, legacy=SimpleNamespace(Adam=DummyLegacyAdam) + ) + ), + ) + monkeypatch.setattr("wf_psf.utils.optimizer.tf", fake_tf) + + opt = get_optimizer(optimizer_input) + assert isinstance(opt, DummyAdam) + assert opt.name.lower() == "adam" + assert opt.learning_rate == expected_lr + assert opt.beta_1 == expected_beta1 + + +# Parametric test for RAdam with overrides +@pytest.mark.parametrize( + "optimizer_input, expected_lr", + [ + ("rectified_adam", 0.001), + ({"name": "rectified_adam", "learning_rate": 0.01}, 0.01), + (RecursiveNamespace(name="rectified_adam", learning_rate=0.02), 0.02), + ], +) +def test_radam_optimizer_overrides(monkeypatch, optimizer_input, expected_lr): + # Provide dummy tfa module + dummy_tfa = SimpleNamespace(optimizers=SimpleNamespace(RectifiedAdam=DummyRAdam)) + monkeypatch.setitem(sys.modules,
"tensorflow_addons", dummy_tfa) + + opt = get_optimizer(optimizer_input) + assert isinstance(opt, DummyRAdam) + assert opt._name.lower() == "rectifiedadam" + assert opt.learning_rate == expected_lr + + +def test_legacy_adam_handling(monkeypatch): + """Verify that legacy.Adam is used when TF < 2.11 and parameters are applied correctly.""" + + # Mock TF < 2.11 + fake_tf = SimpleNamespace( + __version__="2.10.0", + keras=SimpleNamespace( + optimizers=SimpleNamespace( + Adam=DummyAdam, legacy=SimpleNamespace(Adam=DummyLegacyAdam) + ) + ), + ) + monkeypatch.setattr("wf_psf.utils.optimizer.tf", fake_tf) + + # Provide RecursiveNamespace input with overrides + opt_config = RecursiveNamespace( + name="adam", + learning_rate=0.02, + beta_1=0.85, + beta_2=0.95, + epsilon=1e-08, + amsgrad=True, + ) + + opt = get_optimizer(opt_config) + assert isinstance(opt, DummyLegacyAdam) + assert opt.name.lower() == "adam" + assert opt.learning_rate == 0.02 + assert opt.beta_1 == 0.85 + assert opt.beta_2 == 0.95 + assert opt.epsilon == 1e-08 + assert opt.amsgrad is True diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index d04fb1d2..bb0e3df9 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -10,10 +10,10 @@ import numpy as np import time import tensorflow as tf -import tensorflow_addons as tfa import logging from wf_psf.psf_models import psf_models import wf_psf.training.train_utils as train_utils +from wf_psf.utils.optimizer import get_optimizer logger = logging.getLogger(__name__) @@ -428,10 +428,12 @@ def train( ) # Prepare the optimizers - param_optim = tfa.optimizers.RectifiedAdam( - learning_rate=training_handler.learning_rate_params[current_cycle - 1] + param_optim = get_optimizer( + optimizer_config=training_handler.training_hparams.optimizer, + learning_rate=training_handler.learning_rate_params[current_cycle - 1], ) - non_param_optim = tfa.optimizers.RectifiedAdam( + non_param_optim = get_optimizer( + optimizer_config=training_handler.training_hparams.optimizer, learning_rate=training_handler.learning_rate_non_params[current_cycle - 1] ) logger.info(f"Starting cycle {current_cycle}..") diff --git a/src/wf_psf/utils/interpolation.py b/src/wf_psf/utils/interpolation.py new file mode 100644 index 00000000..5ee37209 --- /dev/null +++ b/src/wf_psf/utils/interpolation.py @@ -0,0 +1,321 @@ +# ============================================================================ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Modified by CosmoStat Laboratory, 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" + +import tensorflow as tf +from wf_psf.utils.types import FloatTensorLike, TensorLike + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x: TensorLike, y: TensorLike) -> tf.Tensor: + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). 
+ + Computes the pairwise distances between rows of x and rows of y. + + Args: + x: `[batch_size, n, d]` float `Tensor`. + y: `[batch_size, m, d]` float `Tensor`. + + Returns: + squared_dists: `[batch_size, n, m]` float `Tensor`, where + `squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`. + """ + x_norm_squared = tf.reduce_sum(tf.square(x), 2) + y_norm_squared = tf.reduce_sum(tf.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = tf.expand_dims(y_norm_squared, 1) + + x_y_transpose = tf.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = + # x_bi'x_bi - 2 x_bi'y_bj + y_bj'y_bj + squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile + + return squared_dists + + +def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor: + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using + `_cross_squared_distance_matrix(x, x)` + + Args: + x: `[batch_size, n, d]` float `Tensor`. + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + `squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`. + """ + + x_x_transpose = tf.matmul(x, x, adjoint_b=True) + x_norm_squared = tf.linalg.diag_part(x_x_transpose) + x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = + # = x_bi'x_bi - 2 x_bi'x_bj + x_bj'x_bj + squared_dists = ( + x_norm_squared_tile + - 2 * x_x_transpose + + tf.transpose(x_norm_squared_tile, [0, 2, 1]) + ) + + return squared_dists + + +def _solve_interpolation( + train_points: TensorLike, + train_values: TensorLike, + order: int, + regularization_weight: FloatTensorLike, +) -> TensorLike: + r"""Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the + 'training' data defined by `(train_points, train_values)` using the kernel + $\phi$. + + Args: + train_points: `[b, n, d]` interpolation centers. + train_values: `[b, n, k]` function values. + order: order of the interpolation. + regularization_weight: weight to place on smoothness regularization term. + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + Raises: + ValueError: if d or k is not fully specified. + """ + + # These dimensions are set dynamically at runtime. + b, n, _ = tf.unstack(tf.shape(train_points), num=3) + + d = train_points.shape[-1] + if d is None: + raise ValueError( + "The dimensionality of the input points (d) must be " + "statically-inferrable." + ) + + k = train_values.shape[-1] + if k is None: + raise ValueError( + "The dimensionality of the output values (k) must be " + "statically-inferrable." + ) + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for Python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with tf.name_scope("construct_linear_system"): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = tf.expand_dims(tf.eye(n, dtype=c.dtype), 0) + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term + # in the linear model.
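+        # Taken together, the blocks below assemble the standard polyharmonic-spline system [A B; B^T 0][w; v] = [f; 0],
+        # where A holds the kernel values phi(pairwise squared distances) and B = [c | 1] carries the linear and bias terms.
+        # The system is solved once below for the RBF weights w and the affine coefficients v.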
+ ones = tf.ones_like(c[..., :1], dtype=c.dtype) + matrix_b = tf.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = tf.concat([matrix_a, tf.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = tf.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = tf.concat([matrix_b, lhs_zeros], 1) # [b, n + d + 1, d + 1] + lhs = tf.concat([left_block, right_block], 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = tf.zeros([b, d + 1, k], train_points.dtype) + rhs = tf.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with tf.name_scope("solve_linear_system"): + w_v = tf.linalg.solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation( + query_points: TensorLike, + train_points: TensorLike, + w: TensorLike, + v: TensorLike, + order: int, +) -> TensorLike: + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at. + train_points: `[b, n, d]` x values that act as the interpolation centers + (the c variables in the Wikipedia article). + w: `[b, n, k]` weights on each interpolation center. + v: `[b, d, k]` weights on each input dimension. + order: order of the interpolation. + + Returns: + Polyharmonic interpolation evaluated at points defined in `query_points`. + """ + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = tf.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = tf.concat( + [query_points, tf.ones_like(query_points[..., :1], train_points.dtype)], 2 + ) + linear_term = tf.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike: + """Coordinate-wise nonlinearity used to define the order of the + interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op. + order: interpolation order. + + Returns: + `phi_k` evaluated coordinate-wise on `r`, for `k = order`. + """ + + # using EPSILON prevents log(0), sqrt(0), etc. + # sqrt(0) is well-defined, but its gradient is not + with tf.name_scope("phi"): + if order == 1: + r = tf.maximum(r, EPSILON) + r = tf.sqrt(r) + return r + elif order == 2: + return 0.5 * r * tf.math.log(tf.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * tf.square(r) * tf.math.log(tf.maximum(r, EPSILON)) + elif order % 2 == 0: + r = tf.maximum(r, EPSILON) + return 0.5 * tf.pow(r, 0.5 * order) * tf.math.log(r) + else: + r = tf.maximum(r, EPSILON) + return tf.pow(r, 0.5 * order) + + +def tfa_interpolate_spline_rbf( + train_points: TensorLike, + train_values: TensorLike, + query_points: TensorLike, + order: int, + regularization_weight: FloatTensorLike = 0.0, + name: str = "interpolate_spline", +) -> tf.Tensor: + r""" + Polyharmonic spline (RBF) interpolation (copied from TensorFlow Addons). + + Interpolate a signal using polyharmonic interpolation.
+ + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function + (RBF) terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term + with a bias. The \\(c_i\\) vectors are 'training' points. + In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function + at the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), + and the vector w sums to 0. With these constraints, the coefficients + can be obtained by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the + training data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may + be less vulnerable to overfitting, particularly for high-order + interpolation. + + Note the interpolation procedure is differentiable with respect to all + inputs besides the order parameter. + + We support dynamically-shaped inputs, where batch_size, n, and m are None + at graph construction time. However, d and k must be known. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional + values evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) + (thin-plate spline), or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with tf.name_scope(name or "interpolate_spline"): + train_points = tf.convert_to_tensor(train_points) + train_values = tf.convert_to_tensor(train_values) + query_points = tf.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with tf.name_scope("solve"): + w, v = _solve_interpolation( + train_points, train_values, order, regularization_weight + ) + + # Then, evaluate the spline at the query locations. + with tf.name_scope("predict"): + query_values = _apply_interpolation(query_points, train_points, w, v, order) + + return query_values
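For reference, a small self-contained call of the vendored routine with the shapes documented above (`[batch, n, d]` points, `[batch, n, k]` values); `order=2` gives the thin-plate kernel used throughout `tf_layers.py`:

    import tensorflow as tf
    from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf

    train_points = tf.constant([[[0.0], [1.0], [2.0]]])  # [1, 3, 1]
    train_values = tf.constant([[[0.0], [1.0], [4.0]]])  # [1, 3, 1]
    query_points = tf.constant([[[0.5], [1.5]]])         # [1, 2, 1]

    # Returns the interpolated values with shape [1, 2, 1].
    values = tfa_interpolate_spline_rbf(
        train_points, train_values, query_points, order=2
    )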
diff --git a/src/wf_psf/utils/optimizer.py b/src/wf_psf/utils/optimizer.py new file mode 100644 index 00000000..61740dcf --- /dev/null +++ b/src/wf_psf/utils/optimizer.py @@ -0,0 +1,85 @@ +"""Optimizer utilities for WF-PSF. + +This module provides utility functions to create optimizers for training or evaluation of PSF models. + +:Author: Jennifer Pollack + +""" + +import tensorflow as tf + + +def is_optimizer_instance(obj): + """Return True if `obj` duck-types as a Keras optimizer instance.""" + return hasattr(obj, "apply_gradients") and hasattr(obj, "get_config") + + +def get_optimizer(optimizer_config=None, **overrides): + """ + Return a configured optimizer instance based on a configuration object or an optimizer name. + + Parameters + ---------- + optimizer_config : RecursiveNamespace, dict, or str, optional + Optimizer configuration (parsed from YAML or built programmatically), or an optimizer name string. + **overrides : keyword arguments + Optional hyperparameters to override values in optimizer_config + (e.g., learning_rate, beta_1, beta_2, epsilon, amsgrad). + + Returns + ------- + tf.keras.optimizers.Optimizer + The configured optimizer instance (`tfa.optimizers.RectifiedAdam` when requested). + """ + # Detect TensorFlow version + version = tuple(map(int, tf.__version__.split(".")[:2])) + is_legacy = version < (2, 11) + + # --- Normalize input to a dictionary + if isinstance(optimizer_config, str): + optimizer_name = optimizer_config.lower() + optimizer_params = {} + elif isinstance(optimizer_config, dict): + optimizer_name = optimizer_config.get("name", "adam").lower() + optimizer_params = dict(optimizer_config) + elif hasattr(optimizer_config, "__dict__"): # RecursiveNamespace + optimizer_name = getattr(optimizer_config, "name", "adam").lower() + optimizer_params = { + k: getattr(optimizer_config, k) for k in optimizer_config.__dict__ + } + else: + optimizer_name = "adam" + optimizer_params = {} + + # Apply any overrides + optimizer_params.update(overrides) + + # Extract learning_rate + learning_rate = optimizer_params.pop("learning_rate", 1e-3) + + # --- Rectified Adam (TensorFlow Addons) + if optimizer_name in ["rectified_adam", "radam"]: + try: + import tensorflow_addons as tfa + except ImportError: + raise ImportError( + "TensorFlow Addons not found. Install with `pip install wf_psf[addons]`." + ) + optimizer_params.pop("name", None) + # Only the learning rate is forwarded; other hyperparameters keep the TFA defaults. + return tfa.optimizers.RectifiedAdam(learning_rate=learning_rate) + + # --- Standard Adam (Legacy or Current) + if optimizer_name == "adam": + opt_cls = ( + tf.keras.optimizers.legacy.Adam if is_legacy else tf.keras.optimizers.Adam + ) + return opt_cls( + learning_rate=learning_rate, + beta_1=optimizer_params.get("beta_1", 0.9), + beta_2=optimizer_params.get("beta_2", 0.999), + epsilon=optimizer_params.get("epsilon", 1e-07), + amsgrad=optimizer_params.get("amsgrad", False), + ) + + raise ValueError(f"Unsupported optimizer: {optimizer_name}") diff --git a/src/wf_psf/utils/types.py b/src/wf_psf/utils/types.py new file mode 100644 index 00000000..2a85b3b4 --- /dev/null +++ b/src/wf_psf/utils/types.py @@ -0,0 +1,84 @@ +# ============================================================================ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Modified by CosmoStat Laboratory, 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Types for typing functions signatures.""" + +from typing import Union, Callable, List + +import importlib +import numpy as np +import tensorflow as tf + +from packaging.version import Version + +# Find KerasTensor. +if Version(tf.__version__).release >= Version("2.16").release: + # Determine if loading keras 2 or 3. + if ( + hasattr(tf.keras, "version") + and Version(tf.keras.version()).release >= Version("3.0").release + ): + from keras import KerasTensor + else: + from tf_keras.src.engine.keras_tensor import KerasTensor +elif Version(tf.__version__).release >= Version("2.13").release: + from keras.src.engine.keras_tensor import KerasTensor +elif Version(tf.__version__).release >= Version("2.5").release: + from keras.engine.keras_tensor import KerasTensor +else: + from tensorflow.python.keras.engine.keras_tensor import KerasTensor + + +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +Initializer = Union[None, dict, str, Callable, tf.keras.initializers.Initializer] +Regularizer = Union[None, dict, str, Callable, tf.keras.regularizers.Regularizer] +Constraint = Union[None, dict, str, Callable, tf.keras.constraints.Constraint] +Activation = Union[None, str, Callable] +if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None: + Optimizer = Union[ + tf.keras.optimizers.Optimizer, tf.keras.optimizers.legacy.Optimizer, str + ] +else: + Optimizer = Union[tf.keras.optimizers.Optimizer, str] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, + KerasTensor, +] +FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] +AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None] \ No newline at end of file diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index e3ae66b7..1b1f2d6d 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -6,19 +6,21 @@ import numpy as np import tensorflow as tf -import tensorflow_addons as tfa import PIL import zernike as zk +from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf _HAS_CV2 = False _HAS_SKIMAGE = False try: import cv2 + _HAS_CV2 = True except ImportError: try: from skimage.transform import downscale_local_mean + _HAS_SKIMAGE = True except ImportError: pass @@ -339,6 +341,7 @@ def downsample_im(input_im, output_dim): "Neither OpenCV nor scikit-image is available for image downsampling." ) + def zernike_generator(n_zernikes, wfe_dim): r""" Generate Zernike maps. @@ -624,7 +627,7 @@ def interpolate_zk(self, single_pos): batch_dims=0, ) # Interpolate - interp_zk = tfa.image.interpolate_spline( + interp_zk = tfa_interpolate_spline_rbf( train_points=tf.expand_dims(rec_pos, axis=0), train_values=tf.expand_dims(rec_zks, axis=0), query_points=tf.expand_dims(single_pos[tf.newaxis, :], axis=0),