From 08847375f66c83fc13cee0a4c3c3f405e03158c2 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 18 Nov 2025 18:08:30 +0100 Subject: [PATCH 001/129] Correct doc string, format errors, unused imports, type hints, etc --- src/wf_psf/utils/read_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index 875ae8ed..48d23e00 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -140,4 +140,4 @@ def read_stream(conf_file): docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: # noqa: UP028 - yield doc + yield doc From 532e985a657bed29563687304934b01d58fa0102 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 18 Nov 2025 18:11:18 +0100 Subject: [PATCH 002/129] Update documentation for v3.0.0: replace modules.rst with api.rst, clean up submodule docs, update conf.py, and update pyproject.toml version --- docs/source/_autosummary/wf_psf.data.rst | 31 +++++++++++ .../wf_psf.data.training_preprocessing.rst | 41 +++++++++++++++ .../_autosummary/wf_psf.metrics.metrics.rst | 32 ++++++++++++ .../wf_psf.metrics.metrics_interface.rst | 35 +++++++++++++ docs/source/_autosummary/wf_psf.metrics.rst | 32 ++++++++++++ .../wf_psf.plotting.plots_interface.rst | 40 +++++++++++++++ docs/source/_autosummary/wf_psf.plotting.rst | 31 +++++++++++ ...wf_psf.psf_models.psf_model_parametric.rst | 29 +++++++++++ ...odels.psf_model_physical_polychromatic.rst | 30 +++++++++++ ...sf.psf_models.psf_model_semiparametric.rst | 30 +++++++++++ .../wf_psf.psf_models.psf_models.rst | 48 +++++++++++++++++ .../wf_psf.psf_models.tf_layers.rst | 36 +++++++++++++ .../wf_psf.psf_models.tf_modules.rst | 33 ++++++++++++ .../wf_psf.psf_models.tf_psf_field.rst | 38 ++++++++++++++ .../wf_psf.psf_models.zernikes.rst | 29 +++++++++++ docs/source/_autosummary/wf_psf.rst | 38 ++++++++++++++ docs/source/_autosummary/wf_psf.run.rst | 30 +++++++++++ .../wf_psf.sims.psf_simulator.rst | 29 +++++++++++ docs/source/_autosummary/wf_psf.sims.rst | 32 ++++++++++++ .../wf_psf.sims.spatial_varying_psf.rst | 33 ++++++++++++ docs/source/_autosummary/wf_psf.training.rst | 32 ++++++++++++ .../_autosummary/wf_psf.training.train.rst | 39 ++++++++++++++ .../wf_psf.training.train_utils.rst | 43 ++++++++++++++++ .../wf_psf.utils.ccd_misalignments.rst | 29 +++++++++++ .../_autosummary/wf_psf.utils.centroids.rst | 39 ++++++++++++++ .../wf_psf.utils.configs_handler.rst | 46 +++++++++++++++++ .../_autosummary/wf_psf.utils.graph_utils.rst | 37 ++++++++++++++ docs/source/_autosummary/wf_psf.utils.io.rst | 29 +++++++++++ .../wf_psf.utils.preprocessing.rst | 31 +++++++++++ .../_autosummary/wf_psf.utils.read_config.rst | 37 ++++++++++++++ docs/source/_autosummary/wf_psf.utils.rst | 38 ++++++++++++++ .../_autosummary/wf_psf.utils.utils.rst | 51 +++++++++++++++++++ 32 files changed, 1128 insertions(+) create mode 100644 docs/source/_autosummary/wf_psf.data.rst create mode 100644 docs/source/_autosummary/wf_psf.data.training_preprocessing.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.rst create mode 100644 docs/source/_autosummary/wf_psf.plotting.plots_interface.rst create mode 100644 docs/source/_autosummary/wf_psf.plotting.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst create mode 100644 
docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_models.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.zernikes.rst create mode 100644 docs/source/_autosummary/wf_psf.rst create mode 100644 docs/source/_autosummary/wf_psf.run.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.psf_simulator.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst create mode 100644 docs/source/_autosummary/wf_psf.training.rst create mode 100644 docs/source/_autosummary/wf_psf.training.train.rst create mode 100644 docs/source/_autosummary/wf_psf.training.train_utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.centroids.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.configs_handler.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.graph_utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.io.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.preprocessing.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.read_config.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.utils.rst diff --git a/docs/source/_autosummary/wf_psf.data.rst b/docs/source/_autosummary/wf_psf.data.rst new file mode 100644 index 00000000..6ba42ec2 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.data.rst @@ -0,0 +1,31 @@ +wf\_psf.data +============ + +.. automodule:: wf_psf.data + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.data.training_preprocessing + diff --git a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst new file mode 100644 index 00000000..cfea4561 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst @@ -0,0 +1,41 @@ +wf\_psf.data.training\_preprocessing +==================================== + +.. automodule:: wf_psf.data.training_preprocessing + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + compute_ccd_misalignment + compute_centroid_correction + extract_star_data + get_np_obs_positions + get_np_zernike_prior + get_obs_positions + get_zernike_prior + + + + + + .. rubric:: Classes + + .. autosummary:: + + DataHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.metrics.rst new file mode 100644 index 00000000..36770eb1 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.metrics.rst @@ -0,0 +1,32 @@ +wf\_psf.metrics.metrics +======================= + +.. automodule:: wf_psf.metrics.metrics + + + + + + + + .. rubric:: Functions + + .. 
autosummary:: + + compute_mono_metric + compute_opd_metrics + compute_poly_metric + compute_shape_metrics + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst new file mode 100644 index 00000000..9675e837 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst @@ -0,0 +1,35 @@ +wf\_psf.metrics.metrics\_interface +================================== + +.. automodule:: wf_psf.metrics.metrics_interface + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + evaluate_model + + + + + + .. rubric:: Classes + + .. autosummary:: + + MetricsParamsHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.rst new file mode 100644 index 00000000..ed007b1e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.rst @@ -0,0 +1,32 @@ +wf\_psf.metrics +=============== + +.. automodule:: wf_psf.metrics + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.metrics.metrics + wf_psf.metrics.metrics_interface + diff --git a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst new file mode 100644 index 00000000..173b7f86 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst @@ -0,0 +1,40 @@ +wf\_psf.plotting.plots\_interface +================================= + +.. automodule:: wf_psf.plotting.plots_interface + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + define_plot_style + get_number_of_stars + make_plot + plot_metrics + + + + + + .. rubric:: Classes + + .. autosummary:: + + MetricsPlotHandler + MonochromaticMetricsPlotHandler + ShapeMetricsPlotHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.plotting.rst b/docs/source/_autosummary/wf_psf.plotting.rst new file mode 100644 index 00000000..8c0db4ac --- /dev/null +++ b/docs/source/_autosummary/wf_psf.plotting.rst @@ -0,0 +1,31 @@ +wf\_psf.plotting +================ + +.. automodule:: wf_psf.plotting + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.plotting.plots_interface + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst new file mode 100644 index 00000000..fb16f0e3 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst @@ -0,0 +1,29 @@ +wf\_psf.psf\_models.psf\_model\_parametric +========================================== + +.. automodule:: wf_psf.psf_models.psf_model_parametric + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFParametricPSFFieldModel + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst new file mode 100644 index 00000000..311652a3 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst @@ -0,0 +1,30 @@ +wf\_psf.psf\_models.psf\_model\_physical\_polychromatic +======================================================= + +.. automodule:: wf_psf.psf_models.psf_model_physical_polychromatic + + + + + + + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + PhysicalPolychromaticFieldFactory + TFPhysicalPolychromaticField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst new file mode 100644 index 00000000..578bbc92 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst @@ -0,0 +1,30 @@ +wf\_psf.psf\_models.psf\_model\_semiparametric +============================================== + +.. automodule:: wf_psf.psf_models.psf_model_semiparametric + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + SemiParamFieldFactory + TFSemiParametricField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst new file mode 100644 index 00000000..5395637e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst @@ -0,0 +1,48 @@ +wf\_psf.psf\_models.psf\_models +=============================== + +.. automodule:: wf_psf.psf_models.psf_models + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + build_PSF_model + generate_zernike_maps_3d + get_psf_model + get_psf_model_weights_filepath + register_psfclass + set_psf_model + simPSF + tf_obscurations + + + + + + .. rubric:: Classes + + .. autosummary:: + + PSFModelBaseFactory + + + + + + .. rubric:: Exceptions + + .. autosummary:: + + PSFModelError + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst new file mode 100644 index 00000000..de58615b --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst @@ -0,0 +1,36 @@ +wf\_psf.psf\_models.tf\_layers +============================== + +.. automodule:: wf_psf.psf_models.tf_layers + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFBatchMonochromaticPSF + TFBatchPolychromaticPSF + TFNonParametricGraphOPD + TFNonParametricMCCDOPDv2 + TFNonParametricPolynomialVariationsOPD + TFPhysicalLayer + TFPolynomialZernikeField + TFZernikeOPD + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst new file mode 100644 index 00000000..7571ede7 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst @@ -0,0 +1,33 @@ +wf\_psf.psf\_models.tf\_modules +=============================== + +.. automodule:: wf_psf.psf_models.tf_modules + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFBuildPhase + TFFftDiffract + TFMonochromaticPSF + TFZernikeMonochromaticPSF + TFZernikeOPD + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst new file mode 100644 index 00000000..8ab4325b --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst @@ -0,0 +1,38 @@ +wf\_psf.psf\_models.tf\_psf\_field +================================== + +.. automodule:: wf_psf.psf_models.tf_psf_field + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + get_ground_truth_zernike + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + GroundTruthPhysicalFieldFactory + GroundTruthSemiParamFieldFactory + TFGroundTruthPhysicalField + TFGroundTruthSemiParametricField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst new file mode 100644 index 00000000..54b92e7c --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst @@ -0,0 +1,29 @@ +wf\_psf.psf\_models.zernikes +============================ + +.. automodule:: wf_psf.psf_models.zernikes + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + zernike_generator + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.rst b/docs/source/_autosummary/wf_psf.rst new file mode 100644 index 00000000..fba774bf --- /dev/null +++ b/docs/source/_autosummary/wf_psf.rst @@ -0,0 +1,38 @@ +wf\_psf +======= + +.. automodule:: wf_psf + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.data + wf_psf.metrics + wf_psf.plotting + wf_psf.psf_models + wf_psf.run + wf_psf.sims + wf_psf.training + wf_psf.utils + diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst new file mode 100644 index 00000000..e1533aa0 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.run.rst @@ -0,0 +1,30 @@ +wf\_psf.run +=========== + +.. automodule:: wf_psf.run + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + mainMethod + setProgramOptions + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst new file mode 100644 index 00000000..7d29d3ba --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst @@ -0,0 +1,29 @@ +wf\_psf.sims.psf\_simulator +=========================== + +.. automodule:: wf_psf.sims.psf_simulator + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + PSFSimulator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.sims.rst b/docs/source/_autosummary/wf_psf.sims.rst new file mode 100644 index 00000000..b0f0901d --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.rst @@ -0,0 +1,32 @@ +wf\_psf.sims +============ + +.. automodule:: wf_psf.sims + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.sims.psf_simulator + wf_psf.sims.spatial_varying_psf + diff --git a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst new file mode 100644 index 00000000..3fee34bc --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst @@ -0,0 +1,33 @@ +wf\_psf.sims.spatial\_varying\_psf +================================== + +.. automodule:: wf_psf.sims.spatial_varying_psf + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + CoordinateHelper + MeshHelper + PolynomialMatrixHelper + SpatialVaryingPSF + ZernikeHelper + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.training.rst b/docs/source/_autosummary/wf_psf.training.rst new file mode 100644 index 00000000..aa692d15 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.rst @@ -0,0 +1,32 @@ +wf\_psf.training +================ + +.. automodule:: wf_psf.training + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. 
autosummary:: + :toctree: + :recursive: + + wf_psf.training.train + wf_psf.training.train_utils + diff --git a/docs/source/_autosummary/wf_psf.training.train.rst b/docs/source/_autosummary/wf_psf.training.train.rst new file mode 100644 index 00000000..f8f3faea --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.train.rst @@ -0,0 +1,39 @@ +wf\_psf.training.train +====================== + +.. automodule:: wf_psf.training.train + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + filepath_chkp_callback + get_gpu_info + get_loss_metrics_monitor_and_outputs + setup_training + train + + + + + + .. rubric:: Classes + + .. autosummary:: + + TrainingParamsHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.training.train_utils.rst b/docs/source/_autosummary/wf_psf.training.train_utils.rst new file mode 100644 index 00000000..153313e1 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.train_utils.rst @@ -0,0 +1,43 @@ +wf\_psf.training.train\_utils +============================= + +.. automodule:: wf_psf.training.train_utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + calculate_sample_weights + configure_optimizer_and_loss + general_train_cycle + get_callbacks + l1_schedule_rule + masked_mse + train_cycle_part + + + + + + .. rubric:: Classes + + .. autosummary:: + + L1ParamScheduler + MaskedMeanSquaredError + MaskedMeanSquaredErrorMetric + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst new file mode 100644 index 00000000..00f69a01 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst @@ -0,0 +1,29 @@ +wf\_psf.utils.ccd\_misalignments +================================ + +.. automodule:: wf_psf.utils.ccd_misalignments + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + CCDMisalignmentCalculator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.centroids.rst b/docs/source/_autosummary/wf_psf.utils.centroids.rst new file mode 100644 index 00000000..64aed9ee --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.centroids.rst @@ -0,0 +1,39 @@ +wf\_psf.utils.centroids +======================= + +.. automodule:: wf_psf.utils.centroids + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + compute_zernike_tip_tilt + decim + degradation_op + lanczos + shift_ker_stack + + + + + + .. rubric:: Classes + + .. autosummary:: + + CentroidEstimator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst new file mode 100644 index 00000000..0f1ca838 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst @@ -0,0 +1,46 @@ +wf\_psf.utils.configs\_handler +============================== + +.. automodule:: wf_psf.utils.configs_handler + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + get_run_config + register_configclass + set_run_config + + + + + + .. rubric:: Classes + + .. autosummary:: + + DataConfigHandler + MetricsConfigHandler + PlottingConfigHandler + TrainingConfigHandler + + + + + + .. rubric:: Exceptions + + .. 
autosummary:: + + ConfigParameterError + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst new file mode 100644 index 00000000..fd373ab4 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst @@ -0,0 +1,37 @@ +wf\_psf.utils.graph\_utils +========================== + +.. automodule:: wf_psf.utils.graph_utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + gen_Pea + pairwise_distances + select_vstar + + + + + + .. rubric:: Classes + + .. autosummary:: + + GraphBuilder + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.io.rst b/docs/source/_autosummary/wf_psf.utils.io.rst new file mode 100644 index 00000000..37155165 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.io.rst @@ -0,0 +1,29 @@ +wf\_psf.utils.io +================ + +.. automodule:: wf_psf.utils.io + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + FileIOHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst new file mode 100644 index 00000000..0f8fed5d --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst @@ -0,0 +1,31 @@ +wf\_psf.utils.preprocessing +=========================== + +.. automodule:: wf_psf.utils.preprocessing + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + defocus_to_zk4_wavediff + defocus_to_zk4_zemax + shift_x_y_to_zk1_2_wavediff + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.read_config.rst b/docs/source/_autosummary/wf_psf.utils.read_config.rst new file mode 100644 index 00000000..adccbe4e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.read_config.rst @@ -0,0 +1,37 @@ +wf\_psf.utils.read\_config +========================== + +.. automodule:: wf_psf.utils.read_config + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + read_conf + read_stream + read_yaml + + + + + + .. rubric:: Classes + + .. autosummary:: + + RecursiveNamespace + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.rst b/docs/source/_autosummary/wf_psf.utils.rst new file mode 100644 index 00000000..3c1f8660 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.rst @@ -0,0 +1,38 @@ +wf\_psf.utils +============= + +.. automodule:: wf_psf.utils + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.utils.ccd_misalignments + wf_psf.utils.centroids + wf_psf.utils.configs_handler + wf_psf.utils.graph_utils + wf_psf.utils.io + wf_psf.utils.preprocessing + wf_psf.utils.read_config + wf_psf.utils.utils + diff --git a/docs/source/_autosummary/wf_psf.utils.utils.rst b/docs/source/_autosummary/wf_psf.utils.utils.rst new file mode 100644 index 00000000..71e7d532 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.utils.rst @@ -0,0 +1,51 @@ +wf\_psf.utils.utils +=================== + +.. automodule:: wf_psf.utils.utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + add_noise + calc_poly_position_mat + compute_unobscured_zernike_projection + convert_to_tf + decimate_im + decompose_tf_obscured_opd_basis + downsample_im + generalised_sigmoid + generate_SED_elems + generate_SED_elems_in_tensorflow + generate_n_mask + generate_packed_elems + load_multi_cycle_params_click + single_mask_generator + zernike_generator + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + IndependentZernikeInterpolation + NoiseEstimator + ZernikeInterpolation + + + + + + + + + From be61b0f09cc8ee7262d2cd06f4d2ecdfeed3bb84 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 21 Nov 2025 12:04:30 +0100 Subject: [PATCH 003/129] Update API reference and clean up package/module docstrings - Removed old auto-generated wf_psf.rst from _autosummary - Updated toc.rst (fixed tab issues, reference api.rst) - Added api.rst with :recursive: directive and wf_psf.run entrypoint - Refined __init__.py docstrings for all subpackages for clarity and consistency - Updated module-level docstrings (purpose, authors, TensorFlow notes, etc.) --- docs/source/_autosummary/wf_psf.rst | 38 ------------------------- docs/source/_autosummary/wf_psf.run.rst | 2 +- 2 files changed, 1 insertion(+), 39 deletions(-) delete mode 100644 docs/source/_autosummary/wf_psf.rst diff --git a/docs/source/_autosummary/wf_psf.rst b/docs/source/_autosummary/wf_psf.rst deleted file mode 100644 index fba774bf..00000000 --- a/docs/source/_autosummary/wf_psf.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf -======= - -.. automodule:: wf_psf - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.data - wf_psf.metrics - wf_psf.plotting - wf_psf.psf_models - wf_psf.run - wf_psf.sims - wf_psf.training - wf_psf.utils - diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst index e1533aa0..a4b725b2 100644 --- a/docs/source/_autosummary/wf_psf.run.rst +++ b/docs/source/_autosummary/wf_psf.run.rst @@ -1,4 +1,4 @@ -wf\_psf.run +wf\_psf.run =========== .. automodule:: wf_psf.run From afa2923d899637000738683864df17b9c2eb307c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 21 Nov 2025 12:47:28 +0100 Subject: [PATCH 004/129] Remove auto-generated _autosummary from repo; will be built in CD --- docs/source/_autosummary/wf_psf.data.rst | 31 ----------- .../wf_psf.data.training_preprocessing.rst | 41 --------------- .../_autosummary/wf_psf.metrics.metrics.rst | 32 ------------ .../wf_psf.metrics.metrics_interface.rst | 35 ------------- docs/source/_autosummary/wf_psf.metrics.rst | 32 ------------ .../wf_psf.plotting.plots_interface.rst | 40 --------------- docs/source/_autosummary/wf_psf.plotting.rst | 31 ----------- ...wf_psf.psf_models.psf_model_parametric.rst | 29 ----------- ...odels.psf_model_physical_polychromatic.rst | 30 ----------- ...sf.psf_models.psf_model_semiparametric.rst | 30 ----------- .../wf_psf.psf_models.psf_models.rst | 48 ----------------- .../wf_psf.psf_models.tf_layers.rst | 36 ------------- .../wf_psf.psf_models.tf_modules.rst | 33 ------------ .../wf_psf.psf_models.tf_psf_field.rst | 38 -------------- .../wf_psf.psf_models.zernikes.rst | 29 ----------- docs/source/_autosummary/wf_psf.run.rst | 30 ----------- .../wf_psf.sims.psf_simulator.rst | 29 ----------- docs/source/_autosummary/wf_psf.sims.rst | 32 ------------ .../wf_psf.sims.spatial_varying_psf.rst | 33 ------------ docs/source/_autosummary/wf_psf.training.rst | 32 ------------ .../_autosummary/wf_psf.training.train.rst | 39 -------------- .../wf_psf.training.train_utils.rst | 43 ---------------- .../wf_psf.utils.ccd_misalignments.rst | 29 ----------- .../_autosummary/wf_psf.utils.centroids.rst | 39 -------------- .../wf_psf.utils.configs_handler.rst | 46 ----------------- .../_autosummary/wf_psf.utils.graph_utils.rst | 37 -------------- docs/source/_autosummary/wf_psf.utils.io.rst | 29 ----------- 
.../wf_psf.utils.preprocessing.rst | 31 ----------- .../_autosummary/wf_psf.utils.read_config.rst | 37 -------------- docs/source/_autosummary/wf_psf.utils.rst | 38 -------------- .../_autosummary/wf_psf.utils.utils.rst | 51 ------------------- 31 files changed, 1090 deletions(-) delete mode 100644 docs/source/_autosummary/wf_psf.data.rst delete mode 100644 docs/source/_autosummary/wf_psf.data.training_preprocessing.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.rst delete mode 100644 docs/source/_autosummary/wf_psf.plotting.plots_interface.rst delete mode 100644 docs/source/_autosummary/wf_psf.plotting.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_models.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.zernikes.rst delete mode 100644 docs/source/_autosummary/wf_psf.run.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.psf_simulator.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.train.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.train_utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.centroids.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.configs_handler.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.graph_utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.io.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.preprocessing.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.read_config.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.utils.rst diff --git a/docs/source/_autosummary/wf_psf.data.rst b/docs/source/_autosummary/wf_psf.data.rst deleted file mode 100644 index 6ba42ec2..00000000 --- a/docs/source/_autosummary/wf_psf.data.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.data -============ - -.. automodule:: wf_psf.data - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.data.training_preprocessing - diff --git a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst deleted file mode 100644 index cfea4561..00000000 --- a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst +++ /dev/null @@ -1,41 +0,0 @@ -wf\_psf.data.training\_preprocessing -==================================== - -.. automodule:: wf_psf.data.training_preprocessing - - - - - - - - .. rubric:: Functions - - .. 
autosummary:: - - compute_ccd_misalignment - compute_centroid_correction - extract_star_data - get_np_obs_positions - get_np_zernike_prior - get_obs_positions - get_zernike_prior - - - - - - .. rubric:: Classes - - .. autosummary:: - - DataHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.metrics.rst deleted file mode 100644 index 36770eb1..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.metrics.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.metrics.metrics -======================= - -.. automodule:: wf_psf.metrics.metrics - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - compute_mono_metric - compute_opd_metrics - compute_poly_metric - compute_shape_metrics - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst deleted file mode 100644 index 9675e837..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst +++ /dev/null @@ -1,35 +0,0 @@ -wf\_psf.metrics.metrics\_interface -================================== - -.. automodule:: wf_psf.metrics.metrics_interface - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - evaluate_model - - - - - - .. rubric:: Classes - - .. autosummary:: - - MetricsParamsHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.rst deleted file mode 100644 index ed007b1e..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.metrics -=============== - -.. automodule:: wf_psf.metrics - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.metrics.metrics - wf_psf.metrics.metrics_interface - diff --git a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst deleted file mode 100644 index 173b7f86..00000000 --- a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst +++ /dev/null @@ -1,40 +0,0 @@ -wf\_psf.plotting.plots\_interface -================================= - -.. automodule:: wf_psf.plotting.plots_interface - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - define_plot_style - get_number_of_stars - make_plot - plot_metrics - - - - - - .. rubric:: Classes - - .. autosummary:: - - MetricsPlotHandler - MonochromaticMetricsPlotHandler - ShapeMetricsPlotHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.plotting.rst b/docs/source/_autosummary/wf_psf.plotting.rst deleted file mode 100644 index 8c0db4ac..00000000 --- a/docs/source/_autosummary/wf_psf.plotting.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.plotting -================ - -.. automodule:: wf_psf.plotting - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.plotting.plots_interface - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst deleted file mode 100644 index fb16f0e3..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_parametric -========================================== - -.. automodule:: wf_psf.psf_models.psf_model_parametric - - - - - - - - - - - - .. rubric:: Classes - - .. 
autosummary:: - - TFParametricPSFFieldModel - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst deleted file mode 100644 index 311652a3..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_physical\_polychromatic -======================================================= - -.. automodule:: wf_psf.psf_models.psf_model_physical_polychromatic - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - PhysicalPolychromaticFieldFactory - TFPhysicalPolychromaticField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst deleted file mode 100644 index 578bbc92..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_semiparametric -============================================== - -.. automodule:: wf_psf.psf_models.psf_model_semiparametric - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - SemiParamFieldFactory - TFSemiParametricField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst deleted file mode 100644 index 5395637e..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst +++ /dev/null @@ -1,48 +0,0 @@ -wf\_psf.psf\_models.psf\_models -=============================== - -.. automodule:: wf_psf.psf_models.psf_models - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - build_PSF_model - generate_zernike_maps_3d - get_psf_model - get_psf_model_weights_filepath - register_psfclass - set_psf_model - simPSF - tf_obscurations - - - - - - .. rubric:: Classes - - .. autosummary:: - - PSFModelBaseFactory - - - - - - .. rubric:: Exceptions - - .. autosummary:: - - PSFModelError - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst deleted file mode 100644 index de58615b..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst +++ /dev/null @@ -1,36 +0,0 @@ -wf\_psf.psf\_models.tf\_layers -============================== - -.. automodule:: wf_psf.psf_models.tf_layers - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - TFBatchMonochromaticPSF - TFBatchPolychromaticPSF - TFNonParametricGraphOPD - TFNonParametricMCCDOPDv2 - TFNonParametricPolynomialVariationsOPD - TFPhysicalLayer - TFPolynomialZernikeField - TFZernikeOPD - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst deleted file mode 100644 index 7571ede7..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst +++ /dev/null @@ -1,33 +0,0 @@ -wf\_psf.psf\_models.tf\_modules -=============================== - -.. automodule:: wf_psf.psf_models.tf_modules - - - - - - - - - - - - .. rubric:: Classes - - .. 
autosummary:: - - TFBuildPhase - TFFftDiffract - TFMonochromaticPSF - TFZernikeMonochromaticPSF - TFZernikeOPD - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst deleted file mode 100644 index 8ab4325b..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf.psf\_models.tf\_psf\_field -================================== - -.. automodule:: wf_psf.psf_models.tf_psf_field - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - get_ground_truth_zernike - - - - - - .. rubric:: Classes - - .. autosummary:: - - GroundTruthPhysicalFieldFactory - GroundTruthSemiParamFieldFactory - TFGroundTruthPhysicalField - TFGroundTruthSemiParametricField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst deleted file mode 100644 index 54b92e7c..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.psf\_models.zernikes -============================ - -.. automodule:: wf_psf.psf_models.zernikes - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - zernike_generator - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst deleted file mode 100644 index a4b725b2..00000000 --- a/docs/source/_autosummary/wf_psf.run.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.run -=========== - -.. automodule:: wf_psf.run - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - mainMethod - setProgramOptions - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst deleted file mode 100644 index 7d29d3ba..00000000 --- a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.sims.psf\_simulator -=========================== - -.. automodule:: wf_psf.sims.psf_simulator - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - PSFSimulator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.sims.rst b/docs/source/_autosummary/wf_psf.sims.rst deleted file mode 100644 index b0f0901d..00000000 --- a/docs/source/_autosummary/wf_psf.sims.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.sims -============ - -.. automodule:: wf_psf.sims - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.sims.psf_simulator - wf_psf.sims.spatial_varying_psf - diff --git a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst deleted file mode 100644 index 3fee34bc..00000000 --- a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst +++ /dev/null @@ -1,33 +0,0 @@ -wf\_psf.sims.spatial\_varying\_psf -================================== - -.. automodule:: wf_psf.sims.spatial_varying_psf - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - CoordinateHelper - MeshHelper - PolynomialMatrixHelper - SpatialVaryingPSF - ZernikeHelper - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.training.rst b/docs/source/_autosummary/wf_psf.training.rst deleted file mode 100644 index aa692d15..00000000 --- a/docs/source/_autosummary/wf_psf.training.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.training -================ - -.. 
automodule:: wf_psf.training - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.training.train - wf_psf.training.train_utils - diff --git a/docs/source/_autosummary/wf_psf.training.train.rst b/docs/source/_autosummary/wf_psf.training.train.rst deleted file mode 100644 index f8f3faea..00000000 --- a/docs/source/_autosummary/wf_psf.training.train.rst +++ /dev/null @@ -1,39 +0,0 @@ -wf\_psf.training.train -====================== - -.. automodule:: wf_psf.training.train - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - filepath_chkp_callback - get_gpu_info - get_loss_metrics_monitor_and_outputs - setup_training - train - - - - - - .. rubric:: Classes - - .. autosummary:: - - TrainingParamsHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.training.train_utils.rst b/docs/source/_autosummary/wf_psf.training.train_utils.rst deleted file mode 100644 index 153313e1..00000000 --- a/docs/source/_autosummary/wf_psf.training.train_utils.rst +++ /dev/null @@ -1,43 +0,0 @@ -wf\_psf.training.train\_utils -============================= - -.. automodule:: wf_psf.training.train_utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - calculate_sample_weights - configure_optimizer_and_loss - general_train_cycle - get_callbacks - l1_schedule_rule - masked_mse - train_cycle_part - - - - - - .. rubric:: Classes - - .. autosummary:: - - L1ParamScheduler - MaskedMeanSquaredError - MaskedMeanSquaredErrorMetric - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst deleted file mode 100644 index 00f69a01..00000000 --- a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.utils.ccd\_misalignments -================================ - -.. automodule:: wf_psf.utils.ccd_misalignments - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - CCDMisalignmentCalculator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.centroids.rst b/docs/source/_autosummary/wf_psf.utils.centroids.rst deleted file mode 100644 index 64aed9ee..00000000 --- a/docs/source/_autosummary/wf_psf.utils.centroids.rst +++ /dev/null @@ -1,39 +0,0 @@ -wf\_psf.utils.centroids -======================= - -.. automodule:: wf_psf.utils.centroids - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - compute_zernike_tip_tilt - decim - degradation_op - lanczos - shift_ker_stack - - - - - - .. rubric:: Classes - - .. autosummary:: - - CentroidEstimator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst deleted file mode 100644 index 0f1ca838..00000000 --- a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst +++ /dev/null @@ -1,46 +0,0 @@ -wf\_psf.utils.configs\_handler -============================== - -.. automodule:: wf_psf.utils.configs_handler - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - get_run_config - register_configclass - set_run_config - - - - - - .. rubric:: Classes - - .. autosummary:: - - DataConfigHandler - MetricsConfigHandler - PlottingConfigHandler - TrainingConfigHandler - - - - - - .. rubric:: Exceptions - - .. 
autosummary:: - - ConfigParameterError - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst deleted file mode 100644 index fd373ab4..00000000 --- a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst +++ /dev/null @@ -1,37 +0,0 @@ -wf\_psf.utils.graph\_utils -========================== - -.. automodule:: wf_psf.utils.graph_utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - gen_Pea - pairwise_distances - select_vstar - - - - - - .. rubric:: Classes - - .. autosummary:: - - GraphBuilder - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.io.rst b/docs/source/_autosummary/wf_psf.utils.io.rst deleted file mode 100644 index 37155165..00000000 --- a/docs/source/_autosummary/wf_psf.utils.io.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.utils.io -================ - -.. automodule:: wf_psf.utils.io - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - FileIOHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst deleted file mode 100644 index 0f8fed5d..00000000 --- a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.utils.preprocessing -=========================== - -.. automodule:: wf_psf.utils.preprocessing - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - defocus_to_zk4_wavediff - defocus_to_zk4_zemax - shift_x_y_to_zk1_2_wavediff - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.read_config.rst b/docs/source/_autosummary/wf_psf.utils.read_config.rst deleted file mode 100644 index adccbe4e..00000000 --- a/docs/source/_autosummary/wf_psf.utils.read_config.rst +++ /dev/null @@ -1,37 +0,0 @@ -wf\_psf.utils.read\_config -========================== - -.. automodule:: wf_psf.utils.read_config - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - read_conf - read_stream - read_yaml - - - - - - .. rubric:: Classes - - .. autosummary:: - - RecursiveNamespace - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.rst b/docs/source/_autosummary/wf_psf.utils.rst deleted file mode 100644 index 3c1f8660..00000000 --- a/docs/source/_autosummary/wf_psf.utils.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf.utils -============= - -.. automodule:: wf_psf.utils - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.utils.ccd_misalignments - wf_psf.utils.centroids - wf_psf.utils.configs_handler - wf_psf.utils.graph_utils - wf_psf.utils.io - wf_psf.utils.preprocessing - wf_psf.utils.read_config - wf_psf.utils.utils - diff --git a/docs/source/_autosummary/wf_psf.utils.utils.rst b/docs/source/_autosummary/wf_psf.utils.utils.rst deleted file mode 100644 index 71e7d532..00000000 --- a/docs/source/_autosummary/wf_psf.utils.utils.rst +++ /dev/null @@ -1,51 +0,0 @@ -wf\_psf.utils.utils -=================== - -.. automodule:: wf_psf.utils.utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - add_noise - calc_poly_position_mat - compute_unobscured_zernike_projection - convert_to_tf - decimate_im - decompose_tf_obscured_opd_basis - downsample_im - generalised_sigmoid - generate_SED_elems - generate_SED_elems_in_tensorflow - generate_n_mask - generate_packed_elems - load_multi_cycle_params_click - single_mask_generator - zernike_generator - - - - - - .. rubric:: Classes - - .. 
autosummary::
-   
-      IndependentZernikeInterpolation
-      NoiseEstimator
-      ZernikeInterpolation
-   
-   
-   
-
-   
-   
-   
-

From 6ecb9ba60e899c0558f960267319ed4fd3c91293 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Mon, 24 Nov 2025 15:41:34 +0100
Subject: [PATCH 005/129] Add arduino formatting to directory structure example

---
 docs/source/basic_execution.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/basic_execution.md b/docs/source/basic_execution.md
index c29f30fe..55f9c089 100644
--- a/docs/source/basic_execution.md
+++ b/docs/source/basic_execution.md
@@ -34,7 +34,7 @@ WaveDiff creates an output directory at the location specified by `--outputdir`
 Inside this directory, the following subdirectories are generated:
 
 (wf-outputs)=
-```
+```arduino
 wf-outputs-20231119151932213823
 ├── checkpoint
 ├── config

From 11bad16c29e255d94f9d00b8153d939321395e77 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 25 Nov 2025 10:23:11 +0100
Subject: [PATCH 006/129] Revise WaveDiff how-to guide in doc (part 1)

- Shorten sections to make them developer-friendly
- Added new structure for each section: Purpose, Key Fields, Notes,
  General Notes, etc., where applicable
- Partial completion - new PR is required to verify optional versus
  required settings

---
 docs/source/configuration.md | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/docs/source/configuration.md b/docs/source/configuration.md
index ba87cef2..3d2f31bd 100644
--- a/docs/source/configuration.md
+++ b/docs/source/configuration.md
@@ -20,7 +20,13 @@ You configure these tasks by passing a configuration file to the `wavediff` comm
 
 WaveDiff expects the following configuration files under the `config/` directory:
 
-```
+You configure these tasks by passing a configuration file to the `wavediff` command (e.g., `--config configs.yaml`).
+
+## Configuration File Structure
+
+WaveDiff expects the following configuration files under the `config/` directory:
+
+```arduino
 config
 ├── configs.yaml
 ├── data_config.yaml
@@ -220,7 +226,6 @@ training_hparams:
     n_epochs_non_params: [100, 120]
 ```
 
-
 (metrics_config)=
 ## `metrics_config.yaml` — Metrics Configuration
 
@@ -402,7 +407,10 @@ plotting_params:
 
 ### 4. 
Example Directory Structure
 Below is an example of three WaveDiff runs stored under a single parent directory:
-```
+**Example Directory Structure**
+Below is an example of three WaveDiff runs stored under a single parent directory:
+
+```arduino
 wf-outputs/
 ├── wf-outputs-202305271829
 │   ├── config

From 59dbccd5df4f104b2118b5584e793807309a047c Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 25 Nov 2025 14:55:37 +0100
Subject: [PATCH 007/129] Fix myst error with invalid code format

---
 docs/source/basic_execution.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/basic_execution.md b/docs/source/basic_execution.md
index 55f9c089..c29f30fe 100644
--- a/docs/source/basic_execution.md
+++ b/docs/source/basic_execution.md
@@ -34,7 +34,7 @@ WaveDiff creates an output directory at the location specified by `--outputdir`
 Inside this directory, the following subdirectories are generated:
 
 (wf-outputs)=
-```arduino
+```
 wf-outputs-20231119151932213823
 ├── checkpoint
 ├── config

From 12b89697f75607527399900aaaf1d32bf9f7dd03 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 13 May 2025 11:33:38 +0200
Subject: [PATCH 008/129] Correct syntax in docstring and generalise exception
 message

---
 src/wf_psf/psf_models/psf_models.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py
index 463d1c52..4c44f698 100644
--- a/src/wf_psf/psf_models/psf_models.py
+++ b/src/wf_psf/psf_models/psf_models.py
@@ -187,24 +187,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None):
 def get_psf_model_weights_filepath(weights_filepath):
     """Get PSF model weights filepath.
 
-    A function to return the basename of the user-specified psf model weights path.
+    A function to return the basename of the user-specified PSF model weights path.
 
     Parameters
     ----------
     weights_filepath: str
-        Basename of the psf model weights to be loaded.
+        Basename of the PSF model weights to be loaded.
 
     Returns
     -------
     str
-        The absolute path concatenated to the basename of the psf model weights to be loaded.
+        The absolute path concatenated to the basename of the PSF model weights to be loaded.
 
     """
 
     try:
         return glob.glob(weights_filepath)[0].split(".")[0]
     except IndexError:
         logger.exception(
-            "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
+            "PSF weights file not found. Check that you've specified the correct weights file in your config file."
) raise PSFModelError("PSF model weights error.") From eaa0e8e915ad2ce9c4f9a35fc78eafc53dc5bccc Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 11:35:24 +0200 Subject: [PATCH 009/129] Add inference and test_inference packages --- src/wf_psf/inference/__init__.py | 0 src/wf_psf/inference/psf_inference.py | 0 src/wf_psf/tests/test_inference/test_psf_inference.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/wf_psf/inference/__init__.py create mode 100644 src/wf_psf/inference/psf_inference.py create mode 100644 src/wf_psf/tests/test_inference/test_psf_inference.py diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/test_psf_inference.py new file mode 100644 index 00000000..e69de29b From 1aebacb4c1cc3acf2eb5dcd5170fb4181c0f8ec6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:19:39 +0200 Subject: [PATCH 010/129] Refactor: Encapsulate logic in psf_models package with subpackages: models and tf_modules, add/rm modules, update import statements and tests --- src/wf_psf/psf_models/models/__init__.py | 0 .../{ => models}/psf_model_parametric.py | 2 +- .../psf_model_physical_polychromatic.py | 8 ++++---- .../{ => models}/psf_model_semiparametric.py | 4 ++-- src/wf_psf/psf_models/tf_modules/__init__.py | 0 .../psf_models/{ => tf_modules}/tf_layers.py | 3 ++- .../psf_models/{ => tf_modules}/tf_modules.py | 0 .../psf_models/{ => tf_modules}/tf_psf_field.py | 4 ++-- .../psf_model_physical_polychromatic_test.py | 16 +++++++++------- .../tests/test_psf_models/psf_models_test.py | 6 +++--- 10 files changed, 23 insertions(+), 20 deletions(-) create mode 100644 src/wf_psf/psf_models/models/__init__.py rename src/wf_psf/psf_models/{ => models}/psf_model_parametric.py (99%) rename src/wf_psf/psf_models/{ => models}/psf_model_physical_polychromatic.py (99%) rename src/wf_psf/psf_models/{ => models}/psf_model_semiparametric.py (99%) create mode 100644 src/wf_psf/psf_models/tf_modules/__init__.py rename src/wf_psf/psf_models/{ => tf_modules}/tf_layers.py (99%) rename src/wf_psf/psf_models/{ => tf_modules}/tf_modules.py (100%) rename src/wf_psf/psf_models/{ => tf_modules}/tf_psf_field.py (99%) diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_parametric.py rename to src/wf_psf/psf_models/models/psf_model_parametric.py index 0cd703d7..3d85d2bc 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -9,7 +9,7 @@ import tensorflow as tf from wf_psf.psf_models.psf_models import register_psfclass -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py rename to 
src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index ad61c1b5..f9ed8765 100644 --- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,11 +10,9 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models import psf_models as psfm -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -22,6 +20,8 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.utils.configs_handler import DataConfigHandler import logging diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_semiparametric.py rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py index dc535204..c370956c 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -10,9 +10,9 @@ import numpy as np import tensorflow as tf from wf_psf.psf_models import psf_models as psfm -from wf_psf.psf_models import tf_layers as tfl +from wf_psf.psf_models.tf_modules import tf_layers as tfl from wf_psf.utils.utils import decompose_tf_obscured_opd_basis -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, ) diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py similarity index 99% rename from src/wf_psf/psf_models/tf_layers.py rename to src/wf_psf/psf_models/tf_modules/tf_layers.py index eda43305..fcb5e8f9 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -7,7 +7,8 @@ """ import tensorflow as tf -from wf_psf.psf_models.tf_modules import TFMonochromaticPSF +import tensorflow_addons as tfa +from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py similarity index 100% rename from src/wf_psf/psf_models/tf_modules.py rename to src/wf_psf/psf_models/tf_modules/tf_modules.py diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py similarity index 99% rename from src/wf_psf/psf_models/tf_psf_field.py rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 0c9ba2f7..e25657d9 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -9,13 +9,13 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( 
TFZernikeOPD, TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, TFPhysicalLayer, ) -from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField +from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField from wf_psf.data.training_preprocessing import get_obs_positions from wf_psf.psf_models import psf_models as psfm import logging diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index cae7b141..e25d608d 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,7 @@ import pytest import numpy as np import tensorflow as tf -from wf_psf.psf_models.psf_model_physical_polychromatic import ( +from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) from wf_psf.utils.configs_handler import DataConfigHandler @@ -54,7 +54,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) @@ -92,7 +92,7 @@ def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) @@ -146,13 +146,13 @@ def test_initialize_physical_layer_mocking( # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) # Create a mock for the TFPhysicalLayer class mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" ) # Create TFPhysicalPolychromaticField instance @@ -176,12 +176,14 @@ def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) # Create a mock for the TFPhysicalLayer class - mocker.patch("wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer") + mocker.patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" + ) # Create TFPhysicalPolychromaticField instance psf_field_instance = TFPhysicalPolychromaticField( diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index 066e1328..b7c906f6 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -7,10 +7,10 @@ """ -from wf_psf.psf_models import ( - psf_models, +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.models import ( psf_model_semiparametric, - psf_model_physical_polychromatic, + psf_model_physical_polychromatic ) import tensorflow as tf import 
numpy as np From d66d2c92e77ad98e05510071205a54dfbe4b7ff7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:29:13 +0200 Subject: [PATCH 011/129] Remove unused module with duplicate zernike_generator function --- src/wf_psf/psf_models/zernikes.py | 58 ------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 src/wf_psf/psf_models/zernikes.py diff --git a/src/wf_psf/psf_models/zernikes.py b/src/wf_psf/psf_models/zernikes.py deleted file mode 100644 index dcfa6e39..00000000 --- a/src/wf_psf/psf_models/zernikes.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Zernikes. - -A module to make Zernike maps. - -:Author: Tobias Liaudat and Jennifer Pollack - -""" - -import numpy as np -import zernike as zk -import logging - -logger = logging.getLogger(__name__) - - -def zernike_generator(n_zernikes, wfe_dim): - """ - Generate Zernike maps. - - Based on the zernike github repository. - https://github.com/jacopoantonello/zernike - - Parameters - ---------- - n_zernikes: int - Number of Zernike modes desired. - wfe_dim: int - Dimension of the Zernike map [wfe_dim x wfe_dim]. - - Returns - ------- - zernikes: list of np.ndarray - List containing the Zernike modes. - The values outside the unit circle are filled with NaNs. - """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes From b9cc93729ff2c3e6d8cbc83fe09db254f004b190 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:31:47 +0200 Subject: [PATCH 012/129] Correct syntax in docstrings and logger messages --- src/wf_psf/utils/configs_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 5d45ba79..07737da0 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -1,4 +1,4 @@ -"""Configs_Handler. +s"""Configs_Handler. A module which provides general utility methods to manage the parameters of the config files @@ -443,12 +443,12 @@ def _load_data_conf(self): def weights_basename_filepath(self): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Returns ------- weights_basename: str - The basename of the psf model weights to be loaded. + The basename of the PSF model weights to be loaded. """ return os.path.join( @@ -502,7 +502,7 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. 
""" From f72e0e5ad482efeb5463c71dbf477843e13be9a8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 17:22:44 +0200 Subject: [PATCH 013/129] Refactor file structure; update import statements in tests; remove unit test due to refactoring --- src/wf_psf/{utils => data}/centroids.py | 2 +- .../data_preprocessing.py} | 2 +- src/wf_psf/data/training_preprocessing.py | 4 +-- src/wf_psf/instrument/__init__.py | 0 .../ccd_misalignments.py | 0 .../centroids_test.py | 10 +++---- .../tests/test_utils/configs_handler_test.py | 27 +++++-------------- 7 files changed, 15 insertions(+), 30 deletions(-) rename src/wf_psf/{utils => data}/centroids.py (99%) rename src/wf_psf/{utils/preprocessing.py => data/data_preprocessing.py} (99%) create mode 100644 src/wf_psf/instrument/__init__.py rename src/wf_psf/{utils => instrument}/ccd_misalignments.py (100%) rename src/wf_psf/tests/{test_utils => test_data}/centroids_test.py (97%) diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 99% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 8b4522bf..e414520e 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,7 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff from typing import Optional diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/data/data_preprocessing.py similarity index 99% rename from src/wf_psf/utils/preprocessing.py rename to src/wf_psf/data/data_preprocessing.py index 210c03e5..44e18436 100644 --- a/src/wf_psf/utils/preprocessing.py +++ b/src/wf_psf/data/data_preprocessing.py @@ -1,4 +1,4 @@ -"""Preprocessing. +"""Data Preprocessing. A module with utils to preprocess data. 
diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py index c1402f06..25180ebe 100644 --- a/src/wf_psf/data/training_preprocessing.py +++ b/src/wf_psf/data/training_preprocessing.py @@ -10,8 +10,8 @@ import numpy as np import wf_psf.utils.utils as utils import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt from fractions import Fraction import logging diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 100% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 97% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 8557704f..c55e18f9 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,7 +9,7 @@ import numpy as np import pytest from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import compute_zernike_tip_tilt, CentroidEstimator +from wf_psf.data.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -133,7 +133,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt with single batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True + "wf_psf.data.centroids.CentroidEstimator", autospec=True ) # Create a mock instance and configure get_intra_pixel_shifts() @@ -144,7 +144,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) @@ -191,7 +191,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt with batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True + "wf_psf.data.centroids.CentroidEstimator", autospec=True ) # Create a mock instance and configure get_intra_pixel_shifts() @@ -202,7 +202,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index d95761e9..9b4039a8 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,18 @@ :Author: Jennifer Pollack - 
""" import pytest +from wf_psf.data.training_preprocessing import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + MetricsConfigHandler, + DataConfigHandler, +) import os @@ -229,21 +232,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler(path_to_tmp_output_dir, path_to_config_dir) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) From a86352caae93a143ef4c52f789f261ab0b8f77ec Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 20:35:12 +0200 Subject: [PATCH 014/129] Update package name in import statement --- src/wf_psf/instrument/ccd_misalignments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 1d2135ba..3b7a1eb3 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff class CCDMisalignmentCalculator: From 2b262288f368d5cbcc6b89530120177c6331ef34 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:54:33 +0200 Subject: [PATCH 015/129] Reorder imports; Refactor MetricsConfigHandler class attributes, methods and variable names| --- src/wf_psf/utils/configs_handler.py | 231 ++++++++++++---------------- 1 file changed, 98 insertions(+), 133 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 07737da0..e6764873 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -253,9 +255,11 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) - self._file_handler = file_handler - self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.data_conf = 
self._load_data_conf() + self._file_handler = file_handler + self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self.training_conf = training_conf + self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) @property def metrics_conf(self): @@ -270,32 +274,27 @@ def metrics_conf(self): """ return self._metrics_conf - @property - def metrics_dir(self): - """Get Metrics Directory. - - A function that returns path - of metrics directory. - - Returns - ------- - str - Absolute path to metrics directory - """ - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): - """Get Training Conf. - - A function to return the training configuration file name. + """Returns the loaded training configuration.""" + return self._training_conf - Returns - ------- - RecursiveNamespace - An instance of the training configuration file. + @training_conf.setter + def training_conf(self, training_conf): """ - return self._training_conf + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info(f"Loading training config from inferred path: {training_conf_path}") + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf @property def plotting_conf(self): @@ -310,112 +309,99 @@ def plotting_conf(self): """ return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - """Get Data Conf. - - A function to return an instance of the DataConfigHandler class. - - Returns - ------- - An instance of the DataConfigHandler class. - """ - return self._load_data_conf() - - @property - def psf_model(self): - """Get PSF Model. - - A function to return an instance of the PSF model - to be evaluated. - - Returns - ------- - psf_model: obj - An instance of the PSF model to be evaluated. - """ - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + self.training_conf, self.data_conf, + weights_path_pattern, ) - @property - def weights_path(self): - """Get Weights Path. - A function to return the full path - of the user-specified psf model weights to be loaded. + def _get_training_conf_path_from_metrics(self): + """ + Retrieves the full path to the training config based on the metrics configuration. Returns ------- str - A string representing the full path to the psf model weights to be loaded. + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. 
+ FileNotFoundError + If the file does not exist at the constructed path. """ - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. - - Helper method to get the trained model path. + trained_model_path = self._get_trained_model_path() - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace - - Returns - ------- - str - A string representing the path to the trained model output run directory. + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e - """ - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), training_conf_filename) - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." - ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, - ) + if not os.path.exists(training_conf_path): + raise FileNotFoundError(f"Training config file not found: {training_conf_path}") - def _load_training_conf(self, training_conf): - """Load Training Conf. + return training_conf_path - Load the training configuration if training_conf is not provided. - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. + def _get_trained_model_path(self): + """ + Determine the trained model path from either: + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- - RecursiveNamespace storing the training configuration parameter settings. + str + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) + trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return training_conf + logger.info(f"Using trained model path from metrics config: {trained_model_path}") + return trained_model_path + + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + return fallback_path def _load_data_conf(self): """Load Data Conf. 
@@ -439,26 +425,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified PSF model weights path. - - Returns - ------- - weights_basename: str - The basename of the PSF model weights to be loaded. - - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -513,7 +479,6 @@ def run(self): self.training_conf.training, self.data_conf, self.psf_model, - self.weights_path, self.metrics_dir, ) From b8375ba3493cf59908360f999c8db73f3f96fc73 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:55:47 +0200 Subject: [PATCH 016/129] Move psf_model weights loader to psf_model_loader.py module --- src/wf_psf/metrics/metrics_interface.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3dff2c6c..c922a4fb 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -351,14 +351,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info(f"Loading PSF model weights from {weights_path}") - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} From f3876fb0d028277b8add157305d3b536fa94d499 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:56:13 +0200 Subject: [PATCH 017/129] Add psf_model_loader module --- src/wf_psf/psf_models/psf_model_loader.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/wf_psf/psf_models/psf_model_loader.py diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..26056e57 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,52 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" +from wf_psf.psf_models.psf_models import ( + get_psf_model, + get_psf_model_weights_filepath +) + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace + Configuration object containing data-related parameters. + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. 
+ + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. + """ + model = get_psf_model(training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + model.load_weights(weights_path) + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model + From ffb49e3f520d083ab93e3dd2794382494d6dbc3b Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:13:38 +0200 Subject: [PATCH 018/129] Remove weights_path arg from evaluate_model method; Update logger.info statement and comment --- src/wf_psf/metrics/metrics_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index c922a4fb..bd6249e2 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -311,7 +311,6 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. @@ -341,8 +340,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) From 100127706ea41073b5ac9f0e6278aa47815a3938 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:14:22 +0200 Subject: [PATCH 019/129] Update variable name and logger statement --- src/wf_psf/utils/configs_handler.py | 62 +++++++++++++++++------------ 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index e6764873..4b08a723 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -1,4 +1,4 @@ -s"""Configs_Handler. +"""Configs_Handler. 
A module which provides general utility methods to manage the parameters of the config files @@ -255,11 +255,13 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) - self.data_conf = self._load_data_conf() - self._file_handler = file_handler - self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self._file_handler = file_handler self.training_conf = training_conf - self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) + self.data_conf = self._load_data_conf() + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) + self.trained_psf_model = self._load_trained_psf_model() @property def metrics_conf(self): @@ -288,7 +290,9 @@ def training_conf(self, training_conf): if training_conf is None: try: training_conf_path = self._get_training_conf_path_from_metrics() - logger.info(f"Loading training config from inferred path: {training_conf_path}") + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) self._training_conf = read_conf(training_conf_path) except Exception as e: logger.error(f"Failed to load training config: {e}") @@ -319,14 +323,11 @@ def _load_trained_psf_model(self): model_name = self.training_conf.training.model_params.model_name id_name = self.training_conf.training.id_name - + weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), ) return load_trained_psf_model( self.training_conf, @@ -334,7 +335,6 @@ def _load_trained_psf_model(self): weights_path_pattern, ) - def _get_training_conf_path_from_metrics(self): """ Retrieves the full path to the training config based on the metrics configuration. @@ -356,22 +356,27 @@ def _get_training_conf_path_from_metrics(self): try: training_conf_filename = self._metrics_conf.metrics.trained_model_config except AttributeError as e: - raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e training_conf_path = os.path.join( - self._file_handler.get_config_dir(trained_model_path), training_conf_filename) + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, + ) if not os.path.exists(training_conf_path): - raise FileNotFoundError(f"Training config file not found: {training_conf_path}") + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" + ) return training_conf_path - def _get_trained_model_path(self): """ Determine the trained model path from either: - - 1. The metrics configuration file (i.e., for metrics-only runs after training), or + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns @@ -384,14 +389,18 @@ def _get_trained_model_path(self): ConfigParameterError If the path specified in the metrics config is invalid or missing. 
""" - trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) if trained_model_path: if not os.path.isdir(trained_model_path): raise ConfigParameterError( f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - logger.info(f"Using trained model path from metrics config: {trained_model_path}") + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" + ) return trained_model_path # Fallback for single-run training + metrics evaluation mode @@ -400,7 +409,9 @@ def _get_trained_model_path(self): self._file_handler.parent_output_dir, self._file_handler.workdir, ) - logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) return fallback_path def _load_data_conf(self): @@ -425,7 +436,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -472,13 +482,13 @@ def run(self): input configuration. """ - logger.info(f"Running metrics evaluation on psf model: {self.weights_path}") + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, + self.trained_psf_model, self.metrics_dir, ) From 8e659d0ecb3ffc5bff6603230e333e6f4d9d88a8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:15:53 +0200 Subject: [PATCH 020/129] Add import logging and create logger object --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 26056e57..1d2e267f 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,11 +7,15 @@ Author: Jennifer Pollack """ +import logging from wf_psf.psf_models.psf_models import ( get_psf_model, get_psf_model_weights_filepath ) + +logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): """ Loads a trained PSF model and applies saved weights. From 6d9c40683620198c3c4524129c76668f36367cf5 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:18:43 +0200 Subject: [PATCH 021/129] Create psf_inference.py module --- src/wf_psf/inference/psf_inference.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e69de29b..4f5b39e5 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,21 @@ +"""Inference. + +A module which provides a set of functions to perform inference +on PSF models. It includes functions to load a trained model, +perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. + +:Authors: Jennifer Pollack + +""" + +import os +import glob +import logging +import numpy as np +from wf_psf.psf_models import psf_models, psf_model_loader +import tensorflow as tf + + +#def prepare_inputs(...): ... +#def generate_psfs(...): ... +#def run_pipeline(...): ... 
From 47e9c22dee1b21f64f72219d619a482a75fc9fed Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:21:11 +0200 Subject: [PATCH 022/129] Remove arg from evaluate_model unit test --- src/wf_psf/tests/test_metrics/metrics_interface_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 44bd09bf..12d51447 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -209,7 +209,6 @@ def test_evaluate_model( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) From e3ad71ab065e5adf643b78a64765e3b0e88471bd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 023/129] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/centroids.py | 71 ++- src/wf_psf/data/data_handler.py | 245 ++++++++++ src/wf_psf/data/data_preprocessing.py | 102 ---- src/wf_psf/data/data_zernike_utils.py | 295 +++++++++++ src/wf_psf/data/training_preprocessing.py | 458 ------------------ src/wf_psf/instrument/ccd_misalignments.py | 40 ++ .../psf_models/tf_modules/tf_psf_field.py | 2 +- src/wf_psf/tests/__init__.py | 1 + src/wf_psf/tests/conftest.py | 2 +- src/wf_psf/tests/test_data/__init__.py | 0 src/wf_psf/tests/test_data/centroids_test.py | 235 ++++----- src/wf_psf/tests/test_data/conftest.py | 61 +++ ...rocessing_test.py => data_handler_test.py} | 214 +------- .../test_data/data_zernike_utils_test.py | 133 +++++ src/wf_psf/tests/test_data/test_data_utils.py | 30 ++ .../test_metrics/metrics_interface_test.py | 4 +- src/wf_psf/tests/test_psf_models/conftest.py | 2 +- .../psf_model_physical_polychromatic_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 2 +- src/wf_psf/utils/configs_handler.py | 2 +- 20 files changed, 992 insertions(+), 909 deletions(-) create mode 100644 src/wf_psf/data/data_handler.py delete mode 100644 src/wf_psf/data/data_preprocessing.py create mode 100644 src/wf_psf/data/data_zernike_utils.py delete mode 100644 src/wf_psf/data/training_preprocessing.py create mode 100644 src/wf_psf/tests/__init__.py create mode 100644 src/wf_psf/tests/test_data/__init__.py rename src/wf_psf/tests/test_data/{training_preprocessing_test.py => data_handler_test.py} (52%) create mode 100644 src/wf_psf/tests/test_data/data_zernike_utils_test.py create mode 100644 src/wf_psf/tests/test_data/test_data_utils.py diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index e414520e..27f7894b 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,10 +8,79 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_handler import extract_star_data +from fractions import Fraction +from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff +import tensorflow as tf from typing import Optional +def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. 
+ + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + data : DataConfigHandler + An object containing training and test datasets, including observed PSFs + and optional star masks. + + batch_size : int, optional + The batch size to use when processing the stars. Default is 1. + + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") + + # Get star mask catalogue only if "masks" exist in both training and test datasets + star_masks = ( + extract_star_data(data=data, train_key="masks", test_key="masks") + if ( + data.training_data.dataset.get("masks") is not None + and data.test_data.dataset.get("masks") is not None + and tf.size(data.training_data.dataset["masks"]) > 0 + and tf.size(data.test_data.dataset["masks"]) > 0 + ) + else None + ) + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + # Ensure star_masks is properly handled + star_masks = star_masks if star_masks is not None else None + + reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i:i + batch_size] + batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py new file mode 100644 index 00000000..e2946b0d --- /dev/null +++ b/src/wf_psf/data/data_handler.py @@ -0,0 +1,245 @@ +"""Data Handler Module. + +Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. + +Includes: +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations + +This module serves as a central interface between raw data and modeling components. + +Authors: Jennifer Pollack , Tobias Liaudat +""" + +import os +import numpy as np +import wf_psf.utils.utils as utils +import tensorflow as tf +from fractions import Fraction +import logging + +logger = logging.getLogger(__name__) + + +class DataHandler: + """Data Handler. + + This class manages loading and processing of training and testing data for use during PSF model training and validation. + It provides methods to access and preprocess the data.
+ + Parameters + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Recursive Namespace object containing parameters for both 'training' and 'test' datasets. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins for SED processing. + load_data : bool, optional + If True, data is loaded and processed during initialization. If False, data loading + is deferred until explicitly called. Default is True. + + Attributes + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Parameters for the current dataset type. + dataset : dict or None + Dictionary containing the loaded dataset, including positions and stars/noisy_stars. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins. + sed_data : tf.Tensor or None + TensorFlow tensor containing processed SED data for training/testing. + load_data_on_init : bool + Flag controlling whether data is loaded during initialization. + """ + + def __init__( + self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True + ): + """ + Initialize the dataset handler for PSF simulation. + + Parameters + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Recursive Namespace object containing parameters for both 'training' and 'test' datasets. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins for SED processing. + load_data : bool, optional + If True, data is loaded and processed during initialization. If False, data loading + is deferred. Default is True. + """ + self.dataset_type = dataset_type + self.data_params = data_params.__dict__[dataset_type] + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + if self.load_data_on_init: + self.load_dataset() + self.process_sed_data() + else: + self.dataset = None + self.sed_data = None + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified dataset type. + + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), + allow_pickle=True, + )[()] + self.dataset["positions"] = tf.convert_to_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if self.dataset_type == "training": + if "noisy_stars" in self.dataset: + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "test": + if "stars" in self.dataset: + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "inference": + pass + + def process_sed_data(self): + """Process SED Data. + + A method to generate and process SED data.
+ + """ + self.sed_data = [ + utils.generate_SED_elems_in_tensorflow( + _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 + ) + for _sed in self.dataset["SEDs"] + ] + self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) + self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + + +def get_np_obs_positions(data): + """Get observed positions in numpy from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + np.ndarray + Numpy array containing the observed positions of the stars. + + Notes + ----- + The observed positions are obtained by concatenating the positions of stars + from both the training and test datasets along the 0th axis. + """ + obs_positions = np.concatenate( + ( + data.training_data.dataset["positions"], + data.test_data.dataset["positions"], + ), + axis=0, + ) + + return obs_positions + + +def get_obs_positions(data): + """Get observed positions from the provided dataset. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + """ + obs_positions = get_np_obs_positions(data) + + return tf.convert_to_tensor(obs_positions, dtype=tf.float32) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """Extract specific star-related data from training and test datasets. + + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the + star training and test datasets such as star stamps or masks, based on the provided keys. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + test_key : str + The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + A NumPy array containing the concatenated data for the given keys. + + Raises + ------ + KeyError + If the specified keys do not exist in the training or test datasets. + + Notes + ----- + - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. + - Ensure that eager execution is enabled when calling this function. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) diff --git a/src/wf_psf/data/data_preprocessing.py b/src/wf_psf/data/data_preprocessing.py deleted file mode 100644 index 44e18436..00000000 --- a/src/wf_psf/data/data_preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Data Preprocessing. - -A module with utils to preprocess data. 
- -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..760adb11 --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,295 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. 
+ +:Author: Tobias Liaudat + +""" + +import numpy as np +import tensorflow as tf +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def get_zernike_prior(model_params, data, batch_size: int=16): + """Get Zernike priors from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the combined Zernike contributions for all stars. + + Notes + ----- + The Zernike priors are obtained by concatenating the priors from both + the training and test datasets along the 0th axis; optional centroid and + CCD misalignment corrections are then summed into the result. + + """ + # Local imports avoid a circular import with wf_psf.data.centroids; + # compute_ccd_misalignment is assumed to reside in the instrument package. + from wf_psf.data.centroids import compute_centroid_correction + from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment + + # List of zernike contribution + zernike_contribution_list = [] + + if model_params.use_prior: + logger.info("Reading in Zernike prior into Zernike contribution list...") + zernike_contribution_list.append(get_np_zernike_prior(data)) + + if model_params.correct_centroids: + logger.info("Adding centroid correction to Zernike contribution list...") + zernike_contribution_list.append( + compute_centroid_correction(model_params, data, batch_size) + ) + + if model_params.add_ccd_misalignments: + logger.info("Adding CCD mis-alignments to Zernike contribution list...") + zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) + + if len(zernike_contribution_list) == 1: + zernike_contribution = zernike_contribution_list[0] + else: + # Get max zk order + max_zk_order = np.max( + np.array( + [ + zk_contribution.shape[1] + for zk_contribution in zernike_contribution_list + ] + ) + ) + + zernike_contribution = np.zeros( + (zernike_contribution_list[0].shape[0], max_zk_order) + ) + + # Pad arrays to get the same length and add the final contribution + for it in range(len(zernike_contribution_list)): + current_zk_order = zernike_contribution_list[it].shape[1] + current_zernike_contribution = np.pad( + zernike_contribution_list[it], + pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], + mode="constant", + constant_values=0, + ) + + zernike_contribution += current_zernike_contribution + + return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) + + +def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 1(2) for a given shift in x(y) in WaveDiff conventions. + + All inputs should be in [m]. + A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, + e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. + + The output zernike coefficient is in [um] units as expected by wavediff.
+ + To match the centroid with a `dx` that has a corresponding `zk1`, + the new PSF should be generated with `-zk1`. + + The same applies to `dy` and `zk2`. + + Parameters + ---------- + dxy : float + Centroid shift in [m]. It can be on the x-axis or the y-axis. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + reference_pix_sampling = 12e-6 + zernike_norm_factor = 2.0 + + # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) + return ( + zernike_norm_factor + * (tel_diameter / 2) + * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) + * 3.0 + ) + +def compute_zernike_tip_tilt( + star_images: np.ndarray, + star_masks: Optional[np.ndarray] = None, + pixel_sampling: float = 12e-6, + reference_shifts: list[float] = [-1/3, -1/3], + sigma_init: float = 2.5, + n_iter: int = 20, +) -> np.ndarray: + """ + Compute Zernike tip-tilt corrections for a batch of PSF images. + + This function estimates the centroid shifts of multiple PSFs and computes + the corresponding Zernike tip-tilt corrections to align them with a reference. + + Parameters + ---------- + star_images : np.ndarray + A batch of PSF images (3D array of shape `(num_images, height, width)`). + star_masks : np.ndarray, optional + A batch of masks (same shape as `star_images`). Each mask can have: + - `0` to ignore the pixel. + - `1` to fully consider the pixel. + - Values in `(0,1]` as weights for partial consideration. + Defaults to None. + pixel_sampling : float, optional + The pixel size in meters. Defaults to `12e-6 m` (12 microns). + reference_shifts : list[float], optional + The target centroid shifts in pixels, specified as `[dy, dx]`. + Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). + sigma_init : float, optional + Initial standard deviation for centroid estimation. Default is `2.5`. + n_iter : int, optional + Number of iterations for centroid refinement. Default is `20`. + + Returns + ------- + np.ndarray + An array of shape `(num_images, 2)`, where: + - Column 0 contains `Zk1` (tip) values. + - Column 1 contains `Zk2` (tilt) values. + + Notes + ----- + - This function processes all images at once using vectorized operations. + - The Zernike coefficients are computed in the WaveDiff convention.
+ """ + from wf_psf.data.centroids import CentroidEstimator + + # Vectorize the centroid computation + centroid_estimator = CentroidEstimator( + im=star_images, + mask=star_masks, + sigma_init=sigma_init, + n_iter=n_iter + ) + + shifts = centroid_estimator.get_intra_pixel_shifts() + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Reshape to ensure it's a column vector (1, 2) + reference_shifts = reference_shifts[None,:] + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + + # Compute displacements + displacements = (reference_shifts - shifts) # + + # Ensure the correct axis order for displacements (x-axis, then y-axis) + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + + # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements + zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call + + # Reshape the result back to the original shape of displacements + zk1_2_array = zk1_2_array.reshape(displacements.shape) + + return zk1_2_array + + +def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in zemax conventions. + + All inputs should be in [m]. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + # Convert to waves with a reference of 800nm + zk4 /= 800e-9 + # Remove the peak to valley value + zk4 /= 2.0 + + return zk4 + + +def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in WaveDifff conventions. + + All inputs should be in [m]. + + The output zernike coefficient is in [um] units as expected by wavediff. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + + # Remove the peak to valley value + zk4 /= 2.0 + + # Change units to [um] as Wavediff uses + zk4 *= 1e6 + + return zk4 diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py deleted file mode 100644 index 25180ebe..00000000 --- a/src/wf_psf/data/training_preprocessing.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Training Data Processing. - -A module to load and preprocess training and validation test data. - -:Authors: Jennifer Pollack and Tobias Liaudat - -""" - -import os -import numpy as np -import wf_psf.utils.utils as utils -import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -class DataHandler: - """Data Handler. 
- - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. - - Attributes - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Parameters for the current dataset type. - dataset : dict or None - Dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins. - sed_data : tf.Tensor or None - TensorFlow tensor containing processed SED data for training/testing. - load_data_on_init : bool - Flag controlling whether data is loaded during initialization. - """ - - def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True - ): - """ - Initialize the dataset handler for PSF simulation. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred. Default is True. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() - - def load_dataset(self): - """Load dataset. - - Load the dataset based on the specified dataset type. - - """ - self.dataset = np.load( - os.path.join(self.data_params.data_dir, self.data_params.file), - allow_pickle=True, - )[()] - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - if self.dataset_type == "training": - if "noisy_stars" in self.dataset: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") - elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - - def process_sed_data(self): - """Process SED Data. - - A method to generate and process SED data. 
- - """ - self.sed_data = [ - utils.generate_SED_elems_in_tensorflow( - _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 - ) - for _sed in self.dataset["SEDs"] - ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) - self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. 
- - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. 
- """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 3b7a1eb3..35343886 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -13,6 +13,46 @@ from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff +def compute_ccd_misalignment(model_params, data): + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. 
+ data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + class CCDMisalignmentCalculator: """CCD Misalignment Calculator. diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index e25657d9..23a19b8b 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.data.data_handler import get_obs_positions from wf_psf.psf_models import psf_models as psfm import logging diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index c55e18f9..85719f4f 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,8 +8,11 @@ import numpy as np import pytest +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator +from wf_psf.data.data_handler import extract_star_data +from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt +from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch -from wf_psf.data.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -28,25 +31,6 @@ def calculate_centroid(image, mask=None): return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - 
return image - - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images - - @pytest.fixture def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" @@ -68,12 +52,6 @@ def simple_star_and_mask(): return image, mask -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" @@ -129,133 +107,92 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.data.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02]] - ) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - - # Expected shifts based on centroid calculation - expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Check dy values similarly - np.testing.assert_allclose( - args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 - np.testing.assert_allclose( - zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 - ) # Zk2 - - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - 
"wf_psf.data.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] - ) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts, - ) + # Mock the internal function calls: + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, ( - f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + # Mock the return values of extract_star_data and compute_zernike_tip_tilt + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) + ) + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, mock_data) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Remove masks from mock_data + mock_data.test_data.dataset["masks"] = None + mock_data.training_data.dataset["masks"] = None + + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + # Mock internal function calls + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) + # Mock extract_star_data to return synthetic star postage stamps + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) + ) - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - # Process the displacements and expected values for each image in the batch - expected_dx = 
(
-        reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1]
-    )  # Expected x-axis shift in meters
-
-    expected_dy = (
-        reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0]
-    )  # Expected y-axis shift in meters
-
-    # Compare expected values with the actual arguments passed to the mock function
-    np.testing.assert_allclose(
-        args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0
-    )
-    np.testing.assert_allclose(
-        args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0
-    )
-
-    # Expected values based on mock side_effect (0.5 * shift)
-    np.testing.assert_allclose(
-        zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5
-    )  # Zk1 for each image
-    np.testing.assert_allclose(
-        zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5
-    )  # Zk2 for each image
 
 
 # Test for centroid calculation without mask
@@ -442,9 +379,9 @@ def test_intra_pixel_shifts(simple_image_with_centroid):
     expected_y_shift = 2.7 - 2.0  # yc - yc0
 
     # Check that the shifts are correct
-    assert np.isclose(shifts[0], expected_x_shift), (
-        f"Expected {expected_x_shift}, got {shifts[0]}"
-    )
-    assert np.isclose(shifts[1], expected_y_shift), (
-        f"Expected {expected_y_shift}, got {shifts[1]}"
-    )
+    assert np.isclose(
+        shifts[0], expected_x_shift
+    ), f"Expected {expected_x_shift}, got {shifts[0]}"
+    assert np.isclose(
+        shifts[1], expected_y_shift
+    ), f"Expected {expected_y_shift}, got {shifts[1]}"
diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py
index 6159d53a..04a56893 100644
--- a/src/wf_psf/tests/test_data/conftest.py
+++ b/src/wf_psf/tests/test_data/conftest.py
@@ -9,8 +9,11 @@
 """
 
 import pytest
+import numpy as np
+import tensorflow as tf
 from wf_psf.utils.read_config import RecursiveNamespace
 from wf_psf.psf_models import psf_models
+from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset
 
 training_config = RecursiveNamespace(
     id_name="-coherent_euclid_200stars",
@@ -93,6 +96,64 @@
 )
 
 
+@pytest.fixture(scope="module")
+def mock_data():
+    """Fixture to provide mock data for testing."""
+    # Mock positions and Zernike priors
+    training_positions = np.array([[1, 2], [3, 4]])
+    test_positions = np.array([[5, 6], [7, 8]])
+    training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]])
+    test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]])
+
+    # Define dummy 5x5 image patches for stars (mock star images)
+    # Define varied values for 5x5 star images
+    noisy_stars = tf.constant([
+        np.arange(25).reshape(5, 5),
+        np.arange(25, 50).reshape(5, 5)
+    ], dtype=tf.float32)
+
+    noisy_masks = tf.constant([
+        np.eye(5),
+        np.ones((5, 5))
+    ], dtype=tf.float32)
+
+    stars = tf.constant([
+        np.full((5, 5), 100),
+        np.full((5, 5), 200)
+    ], dtype=tf.float32)
+
+    masks = tf.constant([
+        np.zeros((5, 5)),
+        np.tri(5)
+    ], dtype=tf.float32)
+
+    return MockData(
+        training_positions, test_positions, training_zernike_priors,
+        test_zernike_priors, noisy_stars, noisy_masks, stars, masks
+    )
+
+@pytest.fixture(scope="module")
+def simple_image():
+    """Fixture for a simple star image."""
+    num_images = 1  # Change this to test with multiple images
+    image = np.zeros((num_images, 5, 5))  # Create a 3D array
+    image[:, 2, 2] = 1  # Place the star at the center for each image
+    return image
+
+@pytest.fixture(scope="module")
+def identity_mask():
+    """Creates a mask where all pixels are fully considered."""
+    return np.ones((5, 5))
+
+@pytest.fixture(scope="module")
+def multiple_images():
+    """Fixture for a batch of images with stars at different positions."""
+    images = np.zeros((3, 5, 5))  # 3 images, each of size 5x5
+    images[0, 2, 2] = 1  # Star at center of image 0
+    images[1, 1, 3] = 1  # Star at (1, 3) in image 1
+    images[2, 3, 1] = 1  # Star at (3, 1) in image 2
+    return images
+
 @pytest.fixture(scope="module", params=[data])
 def data_params():
     return data
diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/data_handler_test.py
similarity index 52%
rename from src/wf_psf/tests/test_data/training_preprocessing_test.py
rename to src/wf_psf/tests/test_data/data_handler_test.py
index 3efc8272..5d167a6b 100644
--- a/src/wf_psf/tests/test_data/training_preprocessing_test.py
+++ b/src/wf_psf/tests/test_data/data_handler_test.py
@@ -1,81 +1,16 @@
 import pytest
 import numpy as np
 import tensorflow as tf
-from wf_psf.utils.read_config import RecursiveNamespace
-from wf_psf.data.training_preprocessing import (
+from wf_psf.data.data_handler import (
     DataHandler,
     get_obs_positions,
-    get_zernike_prior,
     extract_star_data,
-    compute_centroid_correction,
 )
+from wf_psf.utils.read_config import RecursiveNamespace
+import logging
 from unittest.mock import patch
 
 
-class MockData:
-    def __init__(
-        self,
-        training_positions,
-        test_positions,
-        training_zernike_priors,
-        test_zernike_priors,
-        noisy_stars=None,
-        noisy_masks=None,
-        stars=None,
-        masks=None,
-    ):
-        self.training_data = MockDataset(
-            positions=training_positions,
-            zernike_priors=training_zernike_priors,
-            star_type="noisy_stars",
-            stars=noisy_stars,
-            masks=noisy_masks,
-        )
-        self.test_data = MockDataset(
-            positions=test_positions,
-            zernike_priors=test_zernike_priors,
-            star_type="stars",
-            stars=stars,
-            masks=masks,
-        )
-
-
-class MockDataset:
-    def __init__(self, positions, zernike_priors, star_type, stars, masks):
-        self.dataset = {
-            "positions": positions,
-            "zernike_prior": zernike_priors,
-            star_type: stars,
-            "masks": masks,
-        }
-
-
-@pytest.fixture
-def mock_data():
-    # Mock data for testing
-    # Mock training and test positions and Zernike priors
-    training_positions = np.array([[1, 2], [3, 4]])
-    test_positions = np.array([[5, 6], [7, 8]])
-    training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]])
-    test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]])
-    # Mock noisy stars, stars and masks
-    noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
-    noisy_masks = tf.constant([[1], [0]], dtype=tf.float32)
-    stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
-    masks = tf.constant([[0], [1]], dtype=tf.float32)
-
-    return MockData(
-        training_positions,
-        test_positions,
-        training_zernike_priors,
-        test_zernike_priors,
-        noisy_stars,
-        noisy_masks,
-        stars,
-        masks,
-    )
-
-
 def test_load_train_dataset(tmp_path, data_params, simPSF):
     # Create a temporary directory and a temporary data file
     data_dir = tmp_path / "data"
@@ -170,7 +105,7 @@ def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF):
         "training", data_params, simPSF, n_bins_lambda, load_data=False
     )
 
-    with
patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") @@ -197,7 +132,7 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): "test", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'stars' in test dataset.") @@ -230,40 +165,21 @@ def test_get_obs_positions(mock_data): assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - - def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) + np.testing.assert_array_equal(result, expected) @@ -271,7 +187,13 @@ def test_extract_star_data_masks(mock_data): """Test extracting star masks from the dataset.""" result = extract_star_data(mock_data, train_key="masks", test_key="masks") - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) @@ -297,96 +219,6 @@ def test_extract_star_data_tensor_conversion(mock_data): assert result.dtype == np.float32, "The NumPy array should have dtype float32" -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock the internal function calls: - with ( - patch( - 
"wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose( - result[0, :], np.array([0, -0.1, -0.2]) - ) # First star Zernike coefficients - assert np.allclose( - result[1, :], np.array([0, -0.3, -0.4]) - ) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock internal function calls - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array( - [ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4], - ] - ) - assert np.allclose(result, expected_result) - - def test_reference_shifts_broadcasting(): reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts shifts = np.random.rand(2, 2400) # Example shifts array diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..692624be --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,133 @@ + +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_zernike_utils import ( + get_zernike_prior, + compute_zernike_tip_tilt, +) +from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset + +def test_get_zernike_prior(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_shape = ( + 4, + 2, + ) # Assuming 2 Zernike priors for each dataset (training and test) + assert zernike_priors.shape == expected_shape + + +def test_get_zernike_prior_dtype(model_params, mock_data): + zernike_priors = 
get_zernike_prior(model_params, mock_data) + assert zernike_priors.dtype == np.float32 + + +def test_get_zernike_prior_concatenation(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_zernike_priors = tf.convert_to_tensor( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + ) + + assert np.array_equal(zernike_priors, expected_zernike_priors) + + +def test_get_zernike_prior_empty_data(model_params): + empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) + zernike_priors = get_zernike_prior(model_params, empty_data) + assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + + +def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): + """Test compute_zernike_tip_tilt with single batch input and mocks.""" + + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + ) + + # Define test inputs (batch of 1 image) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + + # Expected shifts based on centroid calculation + expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters + expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters + + # Expected calls to the mocked function + # Extract the arguments passed to mock_shift_fn + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + + # Check dy values similarly + np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 + np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 + +def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): + """Test compute_zernike_tip_tilt with batch input and mocks.""" + + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + ) + + # Define test inputs (batch of 3 
images) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts + ) + + # Check if the mock function was called once with the full batch + assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + + # Get the arguments passed to the mock function for the batch of images + args, _ = mock_shift_fn.call_args_list[0] + + print("Shape of args[0]:", args[0].shape) + print("Contents of args[0]:", args[0]) + print("Mock function call args list:", mock_shift_fn.call_args_list) + + # Reshape args[0] to (N, 2) for batch processing + args_array = np.array(args[0]).reshape(-1, 2) + + # Process the displacements and expected values for each image in the batch + expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters + + expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image + np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image \ No newline at end of file diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..a5ead298 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,30 @@ + +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks) + self.test_data = MockDataset( + positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks) + diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 12d51447..bf52f0aa 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler +from wf_psf.data.data_handler import DataHandler @pytest.fixture diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- 
a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index e25d608d..7e967465 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -59,7 +59,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): ) mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True + "wf_psf.data.data_handler.get_obs_positions", return_value=True ) # Create TFPhysicalPolychromaticField instance diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 9b4039a8..2ca1d17d 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -7,7 +7,7 @@ """ import pytest -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 4b08a723..b7ca6444 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,7 +12,7 @@ import os import re import glob -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics from wf_psf.psf_models import psf_models From 0cfb8df7c49f2ede9ab8396c1d2c2dadba52414b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:44:07 +0200 Subject: [PATCH 024/129] Update import statements to new module names --- .../psf_models/models/psf_model_physical_polychromatic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index f9ed8765..baec8923 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,8 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior +from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_zernike_utils import get_zernike_prior from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, From 2946d47fb983eb0dc4f18db7ac9f4f33c6da351b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:57:57 +0200 Subject: [PATCH 025/129] Update DataHandler class docstring to include option for inference dataset handling --- src/wf_psf/data/data_handler.py | 38 
++++++++++++++++++++------------------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index e2946b0d..4e940115 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -25,7 +25,8 @@ class DataHandler:
     """Data Handler.
 
-    This class manages loading and processing of training and testing data for use during PSF model training and validation.
+    This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation.
+
     It provides methods to access and preprocess the data.
 
     Parameters
@@ -44,20 +45,21 @@ class DataHandler:
 
     Attributes
     ----------
-    dataset_type : str
-        Type of dataset ("train" or "test").
-    data_params : RecursiveNamespace
-        Parameters for the current dataset type.
-    dataset : dict or None
-        Dictionary containing the loaded dataset, including positions and stars/noisy_stars.
-    simPSF : PSFSimulator
-        Instance of the PSFSimulator class for simulating PSF models.
-    n_bins_lambda : int
-        Number of wavelength bins.
-    sed_data : tf.Tensor or None
-        TensorFlow tensor containing processed SED data for training/testing.
-    load_data_on_init : bool
-        Flag controlling whether data is loaded during initialization.
+    dataset_type: str
+        A string indicating the type of dataset ("train", "test" or "inference").
+    data_params: Recursive Namespace object
+        A Recursive Namespace object containing training, test or inference data parameters.
+    dataset: dict
+        A dictionary containing the loaded dataset, including positions and stars/noisy_stars.
+    simPSF: object
+        An instance of the SimPSFToolkit class for simulating PSFs.
+    n_bins_lambda: int
+        The number of bins in wavelength.
+    sed_data: tf.Tensor
+        A TensorFlow tensor containing the SED data for training/testing.
+    load_data_on_init: bool, optional
+        A flag used to control data loading steps. If True, data is loaded and processed
+        during initialization. If False, data loading is deferred until explicitly called.
     """
 
     def __init__(
@@ -69,9 +71,9 @@ def __init__(
         Parameters
         ----------
         dataset_type : str
-            Type of dataset ("train" or "test").
-        data_params : RecursiveNamespace
-            Recursive Namespace object containing parameters for both 'train' and 'test' datasets.
+            A string indicating the type of data ("train", "test", or "inference").
+        data_params : Recursive Namespace object
+            A Recursive Namespace object containing parameters for the 'train', 'test', and 'inference' datasets.
         simPSF : PSFSimulator
             Instance of the PSFSimulator class for simulating PSF models.
         n_bins_lambda : int

From 0657e24614db4dad05ff6004f729891a89e49be5 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Sun, 18 May 2025 12:22:27 +0200
Subject: [PATCH 026/129] Refactor data_handler with new utility functions to
 validate and process datasets and update docstrings

---
 src/wf_psf/data/data_handler.py | 184 ++++++++++++++++++++++++--------
 1 file changed, 139 insertions(+), 45 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index 4e940115..63f4650e 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -17,79 +17,122 @@
 import wf_psf.utils.utils as utils
 import tensorflow as tf
 from fractions import Fraction
+from typing import Optional, Union
 import logging
 
 logger = logging.getLogger(__name__)
 
 
 class DataHandler:
-    """Data Handler.
- - This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation. + """ + DataHandler for WaveDiff PSF modeling. - It provides methods to access and preprocess the data. + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, inference, and validation in the WaveDiff framework. Parameters ---------- dataset_type : str - Type of dataset ("train" or "test"). + Indicates the dataset mode ("train", "test", or "inference"). data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. n_bins_lambda : int - Number of wavelength bins for SED processing. + Number of wavelength bins used to discretize SEDs. load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. Attributes ---------- - dataset_type: str - A string indicating the type of dataset ("train", "test" or "inference"). - data_params: Recursive Namespace object - A Recursive Namespace object containing training, test or inference data parameters. - dataset: dict - A dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF: object - An instance of the SimPSFToolkit class for simulating PSF. - n_bins_lambda: int - The number of bins in wavelength. - sed_data: tf.Tensor - A TensorFlow tensor containing the SED data for training/testing. - load_data_on_init: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. """ def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, ): """ - Initialize the dataset handler for PSF simulation. + Initialize the DataHandler for PSF dataset preparation. 
+
+        This constructor sets up the dataset handler used for PSF simulation tasks,
+        such as training, testing, or inference. It supports three modes of use:
+
+        1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing
+           must be triggered manually via `load_dataset()` and `process_sed_data()`.
+        2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly,
+           and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`.
+        3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded
+           from disk using `data_params`, and SEDs are extracted and processed automatically.

         Parameters
         ----------
         dataset_type : str
-            A string indicating the type of data ("train", "test", or "inference").
-        data_params : Recursive Namespace object
-            A Recursive Namespace object containing parameters for both 'train', 'test', 'inference' datasets.
+            One of {"train", "test", "inference"} indicating dataset usage.
+        data_params : RecursiveNamespace
+            Configuration object with paths, preprocessing options, and metadata.
         simPSF : PSFSimulator
-            Instance of the PSFSimulator class for simulating PSF models.
+            Used to convert SEDs to TensorFlow format.
         n_bins_lambda : int
-            Number of wavelength bins for SED processing.
+            Number of wavelength bins for the SEDs.
         load_data : bool, optional
-            If True, data is loaded and processed during initialization. If False, data loading
-            is deferred. Default is True.
+            Whether to automatically load and process the dataset (default: True).
+        dataset : dict or list, optional
+            A pre-loaded dataset to use directly (overrides `load_data`).
+        sed_data : array-like, optional
+            Pre-loaded SED data to use directly. If not provided but `dataset` is,
+            SEDs are taken from `dataset["SEDs"]`.
+
+        Raises
+        ------
+        ValueError
+            If SEDs cannot be found in either `dataset` or as `sed_data`.
+
+        Notes
+        -----
+        - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor
+          `load_data=True` is used.
+        - TensorFlow conversion is performed at the end of initialization via
+          `validate_and_process_dataset()`.
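+
+        Examples
+        --------
+        A minimal sketch of the three modes; `data_params`, `simPSF`, and
+        `my_dataset` stand for pre-configured objects supplied by the caller:
+
+        >>> # 1. Manual mode: nothing is loaded until explicitly requested
+        >>> dh = DataHandler("train", data_params, simPSF, n_bins_lambda=8,
+        ...                  load_data=False)
+        >>> dh.load_dataset()
+        >>> dh.process_sed_data(dh.dataset["SEDs"])
+
+        >>> # 2. Pre-loaded dataset mode, passing the SEDs explicitly
+        >>> dh = DataHandler("test", data_params, simPSF, n_bins_lambda=8,
+        ...                  dataset=my_dataset, sed_data=my_dataset["SEDs"])
+
+        >>> # 3. Automatic loading mode (default): load and process on init
+        >>> dh = DataHandler("train", data_params, simPSF, n_bins_lambda=8)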
""" + self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] + self.data_params = data_params self.simPSF = simPSF self.n_bins_lambda = n_bins_lambda self.load_data_on_init = load_data - if self.load_data_on_init: + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: self.load_dataset() - self.process_sed_data() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() else: self.dataset = None self.sed_data = None @@ -104,6 +147,34 @@ def load_dataset(self): os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "train": + if "noisy_stars" not in self.dataset: + logger.warning("Missing 'noisy_stars' in 'train' dataset.") + elif self.dataset_type == "test": + if "stars" not in self.dataset: + logger.warning("Missing 'stars' in 'test' dataset.") + elif self.dataset_type == "inference": + pass + else: + logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = tf.convert_to_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -119,22 +190,45 @@ def load_dataset(self): self.dataset["stars"] = tf.convert_to_tensor( self.dataset["stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - elif "inference" == self.dataset_type: - pass - def process_sed_data(self): - """Process SED Data. + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. - A method to generate and process SED data. + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
""" + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) - for _sed in self.dataset["SEDs"] + for _sed in sed_data ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) From e28fe00b8c2cef8a1e2db5af4bbb5b3cee2c53c8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 027/129] Update unit tests associated to changes in data_handler.py --- .../tests/test_data/data_handler_test.py | 58 ++++++++++++++----- .../tests/test_utils/configs_handler_test.py | 17 +++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 5d167a6b..42e19a71 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -11,6 +11,26 @@ from unittest.mock import patch +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def test_process_sed_data(data_params, simPSF): + # Test processing SED data without initialization + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda=10, load_data=False + ) + assert data_handler.sed_data is None # SED data should not be processed + + +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda=10, load_data=True + ) + + def test_load_train_dataset(tmp_path, data_params, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" @@ -71,7 +91,11 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): n_bins_lambda = 10 data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False + dataset_type="test", + data_params=data_params.test, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, ) # Call the load_dataset method @@ -83,8 +107,8 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" data_dir = tmp_path / "data" data_dir.mkdir() temp_data_file = data_dir / "train_data.npy" @@ -140,24 +164,32 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): def test_process_sed_data(data_params, simPSF): mock_dataset = { "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + # Missing 'noisy_stars' } # Initialize DataHandler instance n_bins_lambda = 4 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - assert isinstance(data_handler.sed_data, tf.Tensor) - assert data_handler.sed_data.dtype == tf.float32 - assert data_handler.sed_data.shape == 
( - len(data_handler.dataset["positions"]), - n_bins_lambda, - len(["feasible_N", "feasible_wv", "SED_norm"]), + np.save(temp_data_file, mock_dataset) + + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") + + data_handler = DataHandler( + dataset_type="train", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=10, + load_data=False, ) + data_handler.load_dataset() + data_handler.process_sed_data(mock_dataset["SEDs"]) + + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + data_handler._validate_dataset_structure() + mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + def test_get_obs_positions(mock_data): observed_positions = get_obs_positions(mock_data) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 2ca1d17d..f8299997 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -119,10 +119,21 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance ) - # Patch the load_dataset and process_sed_data methods inside DataHandler - mocker.patch.object(DataHandler, "load_dataset") + # Patch process_sed_data method mocker.patch.object(DataHandler, "process_sed_data") + # Patch validate_and_process_datasetmethod + mocker.patch.object(DataHandler, "validate_and_process_dataset") + + # Patch load_dataset to assign dataset + def mock_load_dataset(self): + self.dataset = { + "SEDs": ["dummy_sed_data"], + "positions": ["dummy_positions_data"], + } + + mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset) + # Create DataConfigHandler instance data_config_handler = DataConfigHandler( "/path/to/data_config.yaml", @@ -144,7 +155,7 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke assert ( data_config_handler.batch_size == mock_training_conf.training.training_hparams.batch_size - ) # Default value + ) def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler): From 29e3abf94b711719a02299d94585b63e1a2ecfa4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:24:21 +0200 Subject: [PATCH 028/129] Change exception handling in DataConfigHandler; modify args to DataHandler --- src/wf_psf/utils/configs_handler.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index b7ca6444..a7e42ceb 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -129,28 +129,31 @@ class DataConfigHandler: def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True): try: self.data_conf = read_conf(data_conf) - except FileNotFoundError as e: - logger.exception(e) - exit() - except TypeError as e: + except (FileNotFoundError, TypeError) as e: logger.exception(e) exit() self.simPSF = psf_models.simPSF(training_model_params) + + # Extract sub-configs early + train_params = self.data_conf.data.training + test_params = self.data_conf.data.test + self.training_data = DataHandler( dataset_type="training", - data_params=self.data_conf.data, + data_params=train_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) self.test_data = DataHandler( dataset_type="test", - data_params=self.data_conf.data, + 
data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size From d1c86731cb3ff669f7ed55f747a5ce5c88f68851 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 19 May 2025 10:49:14 +0200 Subject: [PATCH 029/129] Add data and psf_model_imports into inference and sketch out methods --- src/wf_psf/inference/psf_inference.py | 59 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 4f5b39e5..5437ee61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -12,10 +12,61 @@ import glob import logging import numpy as np -from wf_psf.psf_models import psf_models, psf_model_loader +from wf_psf.data.data_handler import DataHandler +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +def prepare_inputs(dataset): + + # Convert dataset to tensorflow Dataset + dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) + + + +def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): + + trained_model_path = model_path + model_subdir = model_dir_name + cycle = cycle + + model_name = training_conf.training.model_params.model_name + id_name = training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + training_conf, + data_conf, + weights_path_pattern, + ) + + +def generate_psfs(psf_model, inputs): + pass + + +def run_pipeline(): + psf_model = get_trained_psf_model( + model_path, + model_dir, + cycle, + training_conf, + data_conf + ) + inputs = prepare_inputs( + + ) + psfs = generate_psfs( + psf_model, + inputs, + batch_size=1, + ) + return psfs -#def prepare_inputs(...): ... -#def generate_psfs(...): ... -#def run_pipeline(...): ... From 0672d33252cfd1ca787516f3d481df6ce7fdc8ea Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:56:50 +0200 Subject: [PATCH 030/129] add base psf inference --- src/wf_psf/inference/psf_inference.py | 182 ++++++++++++++++++-------- 1 file changed, 127 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5437ee61..6787be61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -1,8 +1,8 @@ """Inference. -A module which provides a set of functions to perform inference -on PSF models. It includes functions to load a trained model, -perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. +A module which provides a PSFInference class to perform inference +with trained PSF models. It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
:Authors: Jennifer Pollack @@ -13,60 +13,132 @@ import logging import numpy as np from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +from typing import Optional -def prepare_inputs(dataset): - - # Convert dataset to tensorflow Dataset - dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) - - - -def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): - - trained_model_path = model_path - model_subdir = model_dir_name - cycle = cycle - - model_name = training_conf.training.model_params.model_name - id_name = training_conf.training.id_name - - weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), - ) - return load_trained_psf_model( - training_conf, - data_conf, - weights_path_pattern, - ) - - -def generate_psfs(psf_model, inputs): - pass - - -def run_pipeline(): - psf_model = get_trained_psf_model( - model_path, - model_dir, - cycle, - training_conf, - data_conf - ) - inputs = prepare_inputs( - - ) - psfs = generate_psfs( - psf_model, - inputs, - batch_size=1, - ) - return psfs +class PSFInference: + """Class to perform inference on PSF models.""" + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + cycle: int, + training_conf_path: str, + data_conf_path: str, + batch_size: Optional[int] = None, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.cycle = cycle + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + + # Set source parameters + self.x_field = None + self.y_field = None + self.seds = None + self.trained_psf_model = None + + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + + # Set the number of labmda bins + self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda + + # Set the batch size + self.batch_size = ( + batch_size + if batch_size is not None + else self.training_conf.training.model_params.batch_size + ) + + # Instantiate the PSF simulator object + self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) + + # Instantiate the data handler + self.data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_conf, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + ) + + # Load the trained PSF model + self.trained_psf_model = self.get_trained_psf_model() + + def get_trained_psf_model(self): + """Get the trained PSF model.""" + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + self.trained_model_path, + self.model_subdir, + (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, + self.data_conf, + weights_path_pattern, + ) + + def set_source_parameters(self, x_field, y_field, seds): + """Set the input source parameters for inferring the PSF. + + Parameters + ---------- + x_field : array-like + X coordinates of the sources in WaveDiff format. + y_field : array-like + Y coordinates of the sources in WaveDiff format. 
+ seds : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + It assumes the standard WaveDiff SED format. + + """ + # Positions array is of shape (n_sources, 2) + self.positions = tf.convert_to_tensor( + np.array([x_field, y_field]).T, dtype=tf.float32 + ) + # Process SED data + self.sed_data = self.data_handler.process_sed_data(seds) + + def get_psfs(self): + """Generate PSFs on the input source parameters.""" + + while counter < n_samples: + # Calculate the batch end element + if counter + batch_size <= n_samples: + end_sample = counter + batch_size + else: + end_sample = n_samples + + # Define the batch positions + batch_pos = pos[counter:end_sample, :] + + inputs = [self.positions, self.sed_data] + poly_psfs = self.trained_psf_model(inputs, training=False) + + return poly_psfs + + +# def run_pipeline(): +# psf_model = get_trained_psf_model( +# model_path, model_dir, cycle, training_conf, data_conf +# ) +# inputs = prepare_inputs() +# psfs = generate_psfs( +# psf_model, +# inputs, +# batch_size=1, +# ) +# return psfs From 9fbf8b154279a57d3df7d9eb41b069140689fd53 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:57:05 +0200 Subject: [PATCH 031/129] add common call interface through PSF models --- src/wf_psf/psf_models/models/psf_model_parametric.py | 2 +- src/wf_psf/psf_models/models/psf_model_semiparametric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py index 3d85d2bc..4a28417d 100644 --- a/src/wf_psf/psf_models/models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -215,7 +215,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py index c370956c..7b2ff04d 100644 --- a/src/wf_psf/psf_models/models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. 
[1] From positions to Zernike coefficients From b88d5625a7e779b680771a014fde2e0a30f009ca Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:35:29 +0200 Subject: [PATCH 032/129] add handling of inference params --- src/wf_psf/inference/psf_inference.py | 101 +++++++++++++++++++++----- 1 file changed, 81 insertions(+), 20 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6787be61..a5ca12a6 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -27,16 +27,15 @@ def __init__( self, trained_model_path: str, model_subdir: str, - cycle: int, training_conf_path: str, data_conf_path: str, - batch_size: Optional[int] = None, + inference_conf_path: str, ): self.trained_model_path = trained_model_path self.model_subdir = model_subdir - self.cycle = cycle self.training_conf_path = training_conf_path self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path # Set source parameters self.x_field = None @@ -44,18 +43,24 @@ def __init__( self.seds = None self.trained_psf_model = None + # Set compute PSF placeholder + self.inferred_psfs = None + # Load the training and data configurations self.training_conf = read_conf(training_conf_path) self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) # Set the number of labmda bins - self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda - + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size - self.batch_size = ( - batch_size - if batch_size is not None - else self.training_conf.training.model_params.batch_size + self.batch_size = self.inference_conf.inference.batch_size + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + + # Overwrite the model parameters with the inference configuration + self.training_conf.training.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf ) # Instantiate the PSF simulator object @@ -73,6 +78,29 @@ def __init__( # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. 
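+
+        Examples
+        --------
+        A rough sketch of the intended behaviour, using ``types.SimpleNamespace``
+        as a stand-in for the configuration objects:
+
+        >>> from types import SimpleNamespace
+        >>> train = SimpleNamespace(training=SimpleNamespace(
+        ...     model_params=SimpleNamespace(n_bins_lda=20, output_dim=32)))
+        >>> inf = SimpleNamespace(inference=SimpleNamespace(
+        ...     model_params=SimpleNamespace(n_bins_lda=8)))
+        >>> params = PSFInference.overwrite_model_params(train, inf)
+        >>> params.n_bins_lda  # overwritten by the inference value
+        8
+        >>> params.output_dim  # not in the inference config, left unchanged
+        32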
+ + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -110,25 +138,58 @@ def set_source_parameters(self, x_field, y_field, seds): np.array([x_field, y_field]).T, dtype=tf.float32 ) # Process SED data - self.sed_data = self.data_handler.process_sed_data(seds) - - def get_psfs(self): - """Generate PSFs on the input source parameters.""" + self.data_handler.process_sed_data(seds) + self.sed_data = self.data_handler.sed_data + + def compute_psfs(self): + """Compute the PSFs for the input source parameters.""" + + # Check if source parameters are set + if self.positions is None or self.sed_data is None: + raise ValueError( + "Source parameters not set. Call set_source_parameters first." + ) + + # Get the number of samples + n_samples = self.positions.shape[0] + # Initialize counter + counter = 0 + # Initialize PSF array + self.inferred_psfs = np.zeros((n_samples,)) + psf_array = [] while counter < n_samples: # Calculate the batch end element - if counter + batch_size <= n_samples: - end_sample = counter + batch_size + if counter + self.batch_size <= n_samples: + end_sample = counter + self.batch_size else: end_sample = n_samples - # Define the batch positions - batch_pos = pos[counter:end_sample, :] + # Define the batch positions + batch_pos = self.positions[counter:end_sample, :] + batch_seds = self.sed_data[counter:end_sample, :, :] + + # Generate PSFs for the current batch + batch_inputs = [batch_pos, batch_seds] + batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False) + + # Append to the PSF array + psf_array.append(poly_psfs) + + # Update the counter + counter += self.batch_size + + return tf.concat(psf_array, axis=0) + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs.""" + + pass - inputs = [self.positions, self.sed_data] - poly_psfs = self.trained_psf_model(inputs, training=False) + def get_psf(self, index) -> np.ndarray: + """Generate the generated PSF at a specific index.""" - return poly_psfs + pass # def run_pipeline(): From aa9bad361ac741edc8da9042e4b94e555f3c5617 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 033/129] automatic formatting --- src/wf_psf/data/data_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 63f4650e..84af1eba 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -28,7 +28,7 @@ class DataHandler: DataHandler for WaveDiff PSF modeling. This class manages loading, preprocessing, and TensorFlow conversion of datasets used - for PSF model training, testing, inference, and validation in the WaveDiff framework. + for PSF model training, testing, and inference in the WaveDiff framework. 
Parameters ---------- From 98e4806c540a3ff95216851477ba5a66d07508bc Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:43 +0200 Subject: [PATCH 034/129] add first completed class draft --- src/wf_psf/inference/psf_inference.py | 49 +++++++++++++++++++++------ 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index a5ca12a6..80a2c1c1 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -21,7 +21,23 @@ class PSFInference: - """Class to perform inference on PSF models.""" + """Class to perform inference on PSF models. + + + Parameters + ---------- + trained_model_path : str + Path to the directory containing the trained model. + model_subdir : str + Subdirectory name of the trained model. + training_conf_path : str + Path to the training configuration file used to train the model. + data_conf_path : str + Path to the data configuration file. + inference_conf_path : str + Path to the inference configuration file. + + """ def __init__( self, @@ -55,8 +71,11 @@ def __init__( self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size self.batch_size = self.inference_conf.inference.batch_size + assert self.batch_size > 0, "Batch size must be greater than 0." # Set the cycle to use for inference self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions + self.output_dim = self.inference_conf.inference.model_params.output_dim # Overwrite the model parameters with the inference configuration self.training_conf.training.model_params = self.overwrite_model_params( @@ -73,6 +92,7 @@ def __init__( simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, + dataset=None, ) # Load the trained PSF model @@ -155,8 +175,7 @@ def compute_psfs(self): # Initialize counter counter = 0 # Initialize PSF array - self.inferred_psfs = np.zeros((n_samples,)) - psf_array = [] + self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim)) while counter < n_samples: # Calculate the batch end element @@ -174,22 +193,32 @@ def compute_psfs(self): batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False) # Append to the PSF array - psf_array.append(poly_psfs) + self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy() # Update the counter counter += self.batch_size - return tf.concat(psf_array, axis=0) - def get_psfs(self) -> np.ndarray: - """Get all the generated PSFs.""" + """Get all the generated PSFs. - pass + Returns + ------- + np.ndarray + The generated PSFs for the input source parameters. + Shape is (n_samples, output_dim, output_dim). + """ + return self.inferred_psfs def get_psf(self, index) -> np.ndarray: - """Generate the generated PSF at a specific index.""" + """Generate the generated PSF at a specific index. - pass + Returns + ------- + np.ndarray + The generated PSFs for the input source parameters. + Shape is (output_dim, output_dim). 
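+
+        Examples
+        --------
+        Sketch only; `x_field`, `y_field`, and `seds` are assumed inputs in the
+        standard WaveDiff format:
+
+        >>> psf_inference.set_source_parameters(x_field, y_field, seds)
+        >>> psf_inference.compute_psfs()
+        >>> stack = psf_inference.get_psfs()  # (n_samples, output_dim, output_dim)
+        >>> first = psf_inference.get_psf(0)  # (output_dim, output_dim)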
+ """ + return self.inferred_psfs[index] # def run_pipeline(): From e3ec2d2fc18c03d69fa949fad3eb8d03afd65b00 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:55 +0200 Subject: [PATCH 035/129] add inference config file --- config/inference_conf.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 config/inference_conf.yaml diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml new file mode 100644 index 00000000..af67cde6 --- /dev/null +++ b/config/inference_conf.yaml @@ -0,0 +1,28 @@ + +inference: + # Inference batch size + batch_size: 16 + + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Num of wavelength bins to reconstruct polychromatic objects. + n_bins_lda: 20 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. + output_Q: 3 + + # Oversampling rate used for the OPD/WFE PSF model. + oversampling_rate: 3 + + # Dimension of the pixel PSF postage stamp + output_dim: 32 + + # Dimension of the OPD/Wavefront space. + pupil_diameter: 256 + + + + From 634c81c05bcc67b1f98af241c40318900d221667 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:46:16 +0200 Subject: [PATCH 036/129] remove unused code --- src/wf_psf/inference/psf_inference.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 80a2c1c1..97db58da 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -219,16 +219,3 @@ def get_psf(self, index) -> np.ndarray: Shape is (output_dim, output_dim). """ return self.inferred_psfs[index] - - -# def run_pipeline(): -# psf_model = get_trained_psf_model( -# model_path, model_dir, cycle, training_conf, data_conf -# ) -# inputs = prepare_inputs() -# psfs = generate_psfs( -# psf_model, -# inputs, -# batch_size=1, -# ) -# return psfs From 9075fa6b393958e52b2f36452c3552cba7c3fa27 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:56:51 +0200 Subject: [PATCH 037/129] update params --- config/inference_conf.yaml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index af67cde6..cf4de4ff 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -12,17 +12,8 @@ inference: n_bins_lda: 20 # Downsampling rate to match the oversampled model to the specified telescope's sampling. - output_Q: 3 - - # Oversampling rate used for the OPD/WFE PSF model. - oversampling_rate: 3 + output_Q: 1 # Dimension of the pixel PSF postage stamp output_dim: 32 - - # Dimension of the OPD/Wavefront space. - pupil_diameter: 256 - - - From 1cc14cc83deb98919a27f8962c1d2ef4a6b32cf9 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:57:17 +0200 Subject: [PATCH 038/129] update params --- config/inference_conf.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index cf4de4ff..afb0174c 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -9,11 +9,11 @@ inference: # The following parameters will overwrite the `model_params` in the training config file. model_params: # Num of wavelength bins to reconstruct polychromatic objects. 
-    n_bins_lda: 20
+    n_bins_lda: 8
 
     # Downsampling rate to match the oversampled model to the specified telescope's sampling.
     output_Q: 1
 
     # Dimension of the pixel PSF postage stamp
-    output_dim: 32
+    output_dim: 64

From 72c80b61e5e4206ae97fcfee4fc279a2a0abcc93 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:04:40 +0200
Subject: [PATCH 039/129] update inference
---
 config/inference_conf.yaml | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml
index afb0174c..7e971957 100644
--- a/config/inference_conf.yaml
+++ b/config/inference_conf.yaml
@@ -2,10 +2,22 @@
 inference:
   # Inference batch size
   batch_size: 16
-
   # Cycle to use for inference. Can be: 1, 2, ...
   cycle: 2
 
+  # Paths to the configuration files and trained model directory
+  configs:
+    # Path to the directory containing the trained model
+    training_config_path: models/
+
+    # Subdirectory name of the trained model
+    model_subdir: models
+
+    # Path to the training configuration file used to train the model
+    trained_model_path: config/training_config.yaml
+
+    # Path to the data config file (this could contain prior information)
+    data_conf_path:
+
   # The following parameters will overwrite the `model_params` in the training config file.
   model_params:
     # Num of wavelength bins to reconstruct polychromatic objects.

From 566668ba4e911c1d78cf96302a89b13e06eae576 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:05:00 +0200
Subject: [PATCH 040/129] reduce arguments and add compute psfs when
 appropriate
---
 src/wf_psf/inference/psf_inference.py | 48 +++++++++++++--------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 97db58da..e73231cb 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -26,32 +26,31 @@ class PSFInference:
 
     Parameters
     ----------
-    trained_model_path : str
-        Path to the directory containing the trained model.
-    model_subdir : str
-        Subdirectory name of the trained model.
-    training_conf_path : str
-        Path to the training configuration file used to train the model.
-    data_conf_path : str
-        Path to the data configuration file.
     inference_conf_path : str
         Path to the inference configuration file.
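+
+    Examples
+    --------
+    Illustrative sketch; the YAML path is a placeholder for a user-supplied file:
+
+    >>> psf_inf = PSFInference("config/inference_conf.yaml")
+    >>> psf_inf.set_source_parameters(x_field, y_field, seds)
+    >>> psfs = psf_inf.get_psfs()  # PSFs are computed lazily on first access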
""" - def __init__( - self, - trained_model_path: str, - model_subdir: str, - training_conf_path: str, - data_conf_path: str, - inference_conf_path: str, - ): - self.trained_model_path = trained_model_path - self.model_subdir = model_subdir - self.training_conf_path = training_conf_path - self.data_conf_path = data_conf_path + def __init__(self, inference_conf_path: str): + self.inference_conf_path = inference_conf_path + # Load the training and data configurations + self.inference_conf = read_conf(inference_conf_path) + + # Set config paths + self.config_paths = self.inference_conf.inference.configs.config_paths + self.trained_model_path = self.config_paths.trained_model_path + self.model_subdir = self.config_paths.model_subdir + self.training_config_path = self.config_paths.training_config_path + self.data_conf_path = self.config_paths.data_conf_path + + # Load the training and data configurations + self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: + # Load the data configuration + self.data_conf = read_conf(self.data_conf_path) + else: + self.data_conf = None # Set source parameters self.x_field = None @@ -62,11 +61,6 @@ def __init__( # Set compute PSF placeholder self.inferred_psfs = None - # Load the training and data configurations - self.training_conf = read_conf(training_conf_path) - self.data_conf = read_conf(data_conf_path) - self.inference_conf = read_conf(inference_conf_path) - # Set the number of labmda bins self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size @@ -207,6 +201,8 @@ def get_psfs(self) -> np.ndarray: The generated PSFs for the input source parameters. Shape is (n_samples, output_dim, output_dim). """ + if self.inferred_psfs is None: + self.compute_psfs() return self.inferred_psfs def get_psf(self, index) -> np.ndarray: @@ -218,4 +214,6 @@ def get_psf(self, index) -> np.ndarray: The generated PSFs for the input source parameters. Shape is (output_dim, output_dim). 
""" + if self.inferred_psfs is None: + self.compute_psfs() return self.inferred_psfs[index] From 068d78eaae8b9ead2a2556400affd002db3cdd8b Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:08:49 +0200 Subject: [PATCH 041/129] add config handler class --- src/wf_psf/inference/psf_inference.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e73231cb..d0f2b930 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -20,6 +20,59 @@ from typing import Optional +class InferenceConfigHandler: + ids = ("inference_conf",) + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + training_conf_path: str, + data_conf_path: str, + inference_conf_path: str, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path + + # Overwrite the model parameters with the inference configuration + self.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf + ) + + def read_configurations(self): + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) + + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. + + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + + class PSFInference: """Class to perform inference on PSF models. 
From ec445304cc9bc872d780d311f0679c9a712397a8 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:48:47 +0200 Subject: [PATCH 042/129] set up inference config handler and simplify PSFInferenc init --- src/wf_psf/inference/psf_inference.py | 103 ++++++++++++++------------ 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index d0f2b930..57f84080 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -23,30 +23,43 @@ class InferenceConfigHandler: ids = ("inference_conf",) - def __init__( - self, - trained_model_path: str, - model_subdir: str, - training_conf_path: str, - data_conf_path: str, - inference_conf_path: str, - ): - self.trained_model_path = trained_model_path - self.model_subdir = model_subdir - self.training_conf_path = training_conf_path - self.data_conf_path = data_conf_path + def __init__(self, inference_conf_path: str): self.inference_conf_path = inference_conf_path + # Load the inference configuration + self.read_configurations() + # Overwrite the model parameters with the inference configuration self.model_params = self.overwrite_model_params( self.training_conf, self.inference_conf ) def read_configurations(self): + """Read the configuration files.""" + # Load the inference configuration + self.inference_conf = read_conf(self.inference_conf_path) + # Set config paths + self.set_config_paths() # Load the training and data configurations - self.training_conf = read_conf(training_conf_path) - self.data_conf = read_conf(data_conf_path) - self.inference_conf = read_conf(inference_conf_path) + self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: + # Load the data configuration + self.data_conf = read_conf(self.data_conf_path) + else: + self.data_conf = None + + def set_config_paths(self): + """Extract and set the configuration paths.""" + # Set config paths + self.config_paths = self.inference_conf.inference.configs.config_paths + self.trained_model_path = self.config_paths.trained_model_path + self.model_subdir = self.config_paths.model_subdir + self.training_config_path = self.config_paths.training_config_path + self.data_conf_path = self.config_paths.data_conf_path + + def get_configs(self): + """Get the configurations.""" + return (self.inference_conf, self.training_conf, self.data_conf) @staticmethod def overwrite_model_params(training_conf=None, inference_conf=None): @@ -86,48 +99,25 @@ class PSFInference: def __init__(self, inference_conf_path: str): - self.inference_conf_path = inference_conf_path - # Load the training and data configurations - self.inference_conf = read_conf(inference_conf_path) - - # Set config paths - self.config_paths = self.inference_conf.inference.configs.config_paths - self.trained_model_path = self.config_paths.trained_model_path - self.model_subdir = self.config_paths.model_subdir - self.training_config_path = self.config_paths.training_config_path - self.data_conf_path = self.config_paths.data_conf_path + self.inference_config_handler = InferenceConfigHandler( + inference_conf_path=inference_conf_path + ) - # Load the training and data configurations - self.training_conf = read_conf(self.training_conf_path) - if self.data_conf_path is not None: - # Load the data configuration - self.data_conf = read_conf(self.data_conf_path) - else: - self.data_conf = None + self.inference_conf, self.training_conf, self.data_conf = ( + 
self.inference_config_handler.get_configs() + ) - # Set source parameters + # Init source parameters self.x_field = None self.y_field = None self.seds = None self.trained_psf_model = None - # Set compute PSF placeholder + # Init compute PSF placeholder self.inferred_psfs = None - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - # Overwrite the model parameters with the inference configuration - self.training_conf.training.model_params = self.overwrite_model_params( - self.training_conf, self.inference_conf - ) + # Load inference parameters + self.load_inference_params() # Instantiate the PSF simulator object self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) @@ -145,6 +135,18 @@ def __init__(self, inference_conf_path: str): # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + def load_inference_params(self): + """Load the inference parameters from the configuration file.""" + # Set the number of labmda bins + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size + self.batch_size = self.inference_conf.inference.batch_size + assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions + self.output_dim = self.inference_conf.inference.model_params.output_dim + @staticmethod def overwrite_model_params(training_conf=None, inference_conf=None): """Overwrite model_params of the training_conf with the inference_conf. @@ -188,6 +190,11 @@ def get_trained_psf_model(self): def set_source_parameters(self, x_field, y_field, seds): """Set the input source parameters for inferring the PSF. + Note + ---- + The input source parameters are expected to be in the WaveDiff format. See the simulated data + format for more details. + Parameters ---------- x_field : array-like From 6b073bde66a051502093e4e248fc265da1a3c6f3 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:49:47 +0200 Subject: [PATCH 043/129] remove unused imports --- src/wf_psf/inference/psf_inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 57f84080..797a410f 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -4,20 +4,17 @@ with trained PSF models. It is able to load a trained model, perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
-:Authors: Jennifer Pollack +:Authors: Jennifer Pollack , Tobias Liaudat """ import os -import glob -import logging import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf -from typing import Optional class InferenceConfigHandler: From be08c3bc68cafaad19ef8928ad3b0b4ef02adc33 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:50:52 +0200 Subject: [PATCH 044/129] update inference --- config/inference_conf.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 7e971957..0c846fca 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -5,6 +5,7 @@ inference: # Cycle to use for inference. Can be: 1, 2, ... cycle: 2 + # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model training_config_path: models/ From ec87d79b5933ad8c2c75641e645be2e25d927e04 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 May 2025 11:43:09 +0200 Subject: [PATCH 045/129] Add single-space lines to improve readability; Remove duplicated static method --- src/wf_psf/inference/psf_inference.py | 34 ++++++++------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 797a410f..f8496453 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -35,10 +35,13 @@ def read_configurations(self): """Read the configuration files.""" # Load the inference configuration self.inference_conf = read_conf(self.inference_conf_path) + # Set config paths self.set_config_paths() + # Load the training and data configurations self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: # Load the data configuration self.data_conf = read_conf(self.data_conf_path) @@ -136,37 +139,18 @@ def load_inference_params(self): """Load the inference parameters from the configuration file.""" # Set the number of labmda bins self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size self.batch_size = self.inference_conf.inference.batch_size assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions self.output_dim = self.inference_conf.inference.model_params.output_dim - - @staticmethod - def overwrite_model_params(training_conf=None, inference_conf=None): - """Overwrite model_params of the training_conf with the inference_conf. - - Parameters - ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. 
- - """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params - if model_params is not None and inf_model_params is not None: - for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute - if hasattr(model_params, key): - # Set the attribute of model_params to the new value - setattr(model_params, key, value) - - return model_params - + + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -223,8 +207,10 @@ def compute_psfs(self): # Get the number of samples n_samples = self.positions.shape[0] + # Initialize counter counter = 0 + # Initialize PSF array self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim)) From b5b31f05710bba0eaee0756dc358865f9a3544c4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 26 May 2025 16:19:40 +0200 Subject: [PATCH 046/129] Add additional PSFInference class attributes; update set_source_parameters to use class attributes; assign variable names in get_trained_psf_model; Update class and __init__ docstrings --- src/wf_psf/inference/psf_inference.py | 83 ++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index f8496453..552abd02 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -87,17 +87,67 @@ def overwrite_model_params(training_conf=None, inference_conf=None): class PSFInference: - """Class to perform inference on PSF models. + """ + Perform PSF inference using a pre-trained WaveDiff model. + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. Parameters ---------- - inference_conf_path : str - Path to the inference configuration file. - + inference_conf_path : str, optional + Path to the inference configuration YAML file. This file should define + paths and parameters for the inference, training, and data configurations. + x_field : array-like, optional + Array of x field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + y_field : array-like, optional + Array of y field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + seds : array-like, optional + Spectral energy distributions (SEDs) for the sources being modeled. These + will be used as part of the input to the PSF simulator. + + Attributes + ---------- + inference_config_handler : InferenceConfigHandler + Handler object to load and parse inference, training, and data configs. + inference_conf : dict + Dictionary containing inference configuration settings. + training_conf : dict + Dictionary containing training configuration settings. + data_conf : dict + Dictionary containing data configuration settings. + x_field : array-like + Input x coordinates after transformation (if applicable). + y_field : array-like + Input y coordinates after transformation (if applicable). + seds : array-like + Input spectral energy distributions. + trained_psf_model : keras.Model + Loaded PSF model used for prediction. + inferred_psfs : array-like or None + Array of inferred PSF images, populated after inference is performed. + simPSF : psf_models.simPSF + PSF simulator instance initialized with training model parameters. 
+    data_handler : DataHandler
+        Data handler configured for inference, used to prepare inputs to the model.
+    n_bins_lambda : int
+        Number of spectral bins used for PSF simulation (loaded from config).
+
+    Methods
+    -------
+    load_inference_params()
+        Load parameters required for inference, including spectral binning.
+    get_trained_psf_model()
+        Load and return the trained Keras model for PSF inference.
+    compute_psfs()
+        Run the model on the input data and cache the inferred PSFs.
+    """
-    def __init__(self, inference_conf_path: str):
+    def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None):
 
         self.inference_config_handler = InferenceConfigHandler(
             inference_conf_path=inference_conf_path
         )
@@ -108,9 +158,9 @@ def __init__(self, inference_conf_path: str):
         )
 
         # Init source parameters
-        self.x_field = None
-        self.y_field = None
-        self.seds = None
+        self.x_field = x_field
+        self.y_field = y_field
+        self.seds = seds
         self.trained_psf_model = None
 
         # Init compute PSF placeholder
@@ -149,18 +199,21 @@ def load_inference_params(self):
 
         # Get output psf dimensions
         self.output_dim = self.inference_conf.inference.model_params.output_dim
-
+
     def get_trained_psf_model(self):
         """Get the trained PSF model."""
 
+        # Load the trained PSF model
+        model_path = self.inference_config_handler.trained_model_path
+        model_dir_name = self.inference_config_handler.model_subdir
         model_name = self.training_conf.training.model_params.model_name
         id_name = self.training_conf.training.id_name
 
         weights_path_pattern = os.path.join(
-            self.trained_model_path,
-            self.model_subdir,
-            (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
+            model_path,
+            model_dir_name,
+            (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
         )
         return load_trained_psf_model(
             self.training_conf,
@@ -168,7 +221,7 @@ def get_trained_psf_model(self):
             weights_path_pattern,
         )
 
-    def set_source_parameters(self, x_field, y_field, seds):
+    def set_source_parameters(self):
         """Set the input source parameters for inferring the PSF.
Note @@ -190,10 +243,10 @@ def set_source_parameters(self, x_field, y_field, seds): """ # Positions array is of shape (n_sources, 2) self.positions = tf.convert_to_tensor( - np.array([x_field, y_field]).T, dtype=tf.float32 + np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) # Process SED data - self.data_handler.process_sed_data(seds) + self.data_handler.process_sed_data(self.seds) self.sed_data = self.data_handler.sed_data def compute_psfs(self): From 4c8db2bca53b208845b1d9ed7fac2d50df015c7f Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 27 May 2025 18:09:04 +0100 Subject: [PATCH 047/129] Add checks to convert to np.ndarray and expand dimensions if needed --- src/wf_psf/data/centroids.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 27f7894b..83ccf8c7 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -247,6 +247,19 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" + + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: From fbdd32a78956f2da19b871bbec344c846d8a19ca Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 5 Jun 2025 13:53:34 +0100 Subject: [PATCH 048/129] Update pyproject.toml with numpy dependency limits - sdc-uk --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 187ead5e..3107f12c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' 
dependencies = [ - "numpy>=1.26.4,<2.0", + "numpy>=1.18,<1.24", "scipy", "tensorflow==2.11.0", "tensorflow-estimator", From 94a424e1c1b73c3ad95a661e66976b66166919cb Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 6 Jun 2025 12:53:26 +0200 Subject: [PATCH 049/129] Correct name of psf_inference_test module to follow repo naming convention --- .../{test_psf_inference.py => psf_inference_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/wf_psf/tests/test_inference/{test_psf_inference.py => psf_inference_test.py} (100%) diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/psf_inference_test.py similarity index 100% rename from src/wf_psf/tests/test_inference/test_psf_inference.py rename to src/wf_psf/tests/test_inference/psf_inference_test.py From 44cf3757cf4012887afcdae57b7776a52aa70c88 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:22:14 +0200 Subject: [PATCH 050/129] Correct config subkey names for defining trained_model_path and trained_model_config_path --- config/inference_conf.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 0c846fca..c9d29cb8 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -8,16 +8,16 @@ inference: # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model - training_config_path: models/ + trained_model_path: /path/to/trained/model/ - # Subdirectory name of the trained model - model_subdir: models + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model - # Path to the training configuration file used to train the model - trained_model_path: config/training_config.yaml + # Relative Path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml # Path to the data config file (this could contain prior information) - data_conf_path: + data_config_path: # The following parameters will overwrite the `model_params` in the training config file. 
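  # Example (hypothetical values; any key set here takes precedence over the
  # same key in the training configuration):
  #
  #   model_params:
  #     output_Q: 1
  #     output_dim: 64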
model_params: From cfd3bafbfc3854e2ec1d06fbdfa6dbc1c9a3528b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:24:26 +0200 Subject: [PATCH 051/129] Refactor psf_inference adding PSFInferenceEngine to separate concerns, enabling isolated testing; implement lazy loaders for config handling, model loading, and inference --- src/wf_psf/inference/psf_inference.py | 426 +++++++++++++------------- 1 file changed, 206 insertions(+), 220 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 552abd02..9012e3fc 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -9,6 +9,7 @@ """ import os +from pathlib import Path import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf @@ -20,71 +21,60 @@ class InferenceConfigHandler: ids = ("inference_conf",) - def __init__(self, inference_conf_path: str): - self.inference_conf_path = inference_conf_path + def __init__(self, inference_config_path: str): + self.inference_config_path = inference_config_path + self.inference_config = None + self.training_config = None + self.data_config = None - # Load the inference configuration - self.read_configurations() - # Overwrite the model parameters with the inference configuration - self.model_params = self.overwrite_model_params( - self.training_conf, self.inference_conf - ) - - def read_configurations(self): - """Read the configuration files.""" - # Load the inference configuration - self.inference_conf = read_conf(self.inference_conf_path) - - # Set config paths + def load_configs(self): + """Load configuration files based on the inference config.""" + self.inference_config = read_conf(self.inference_config_path) self.set_config_paths() - - # Load the training and data configurations - self.training_conf = read_conf(self.training_conf_path) - - if self.data_conf_path is not None: + self.training_config = read_conf(self.trained_model_config_path) + + if self.data_config_path is not None: # Load the data configuration - self.data_conf = read_conf(self.data_conf_path) - else: - self.data_conf = None + self.data_conf = read_conf(self.data_config_path) + def set_config_paths(self): """Extract and set the configuration paths.""" # Set config paths - self.config_paths = self.inference_conf.inference.configs.config_paths - self.trained_model_path = self.config_paths.trained_model_path - self.model_subdir = self.config_paths.model_subdir - self.training_config_path = self.config_paths.training_config_path - self.data_conf_path = self.config_paths.data_conf_path + config_paths = self.inference_config.inference.configs + + self.trained_model_path = Path(config_paths.trained_model_path) + self.model_subdir = config_paths.model_subdir + self.trained_model_config_path = self.trained_model_path / config_paths.trained_model_config_path + self.data_config_path = config_paths.data_config_path - def get_configs(self): - """Get the configurations.""" - return (self.inference_conf, self.training_conf, self.data_conf) @staticmethod - def overwrite_model_params(training_conf=None, inference_conf=None): - """Overwrite model_params of the training_conf with the inference_conf. + def overwrite_model_params(training_config=None, inference_config=None): + """ + Overwrite training model_params with values from inference_config if available. 
Parameters ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. - + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params - if model_params is not None and inf_model_params is not None: + if model_params and inf_model_params: for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute if hasattr(model_params, key): - # Set the attribute of model_params to the new value setattr(model_params, key, value) - return model_params - + class PSFInference: """ @@ -96,220 +86,216 @@ class PSFInference: Parameters ---------- - inference_conf_path : str, optional - Path to the inference configuration YAML file. This file should define - paths and parameters for the inference, training, and data configurations. + inference_config_path : str + Path to the inference configuration YAML file. x_field : array-like, optional - Array of x field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + x coordinates in SHE convention. y_field : array-like, optional - Array of y field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + y coordinates in SHE convention. seds : array-like, optional - Spectral energy distributions (SEDs) for the sources being modeled. These - will be used as part of the input to the PSF simulator. - - Attributes - ---------- - inference_config_handler : InferenceConfigHandler - Handler object to load and parse inference, training, and data configs. - inference_conf : dict - Dictionary containing inference configuration settings. - training_conf : dict - Dictionary containing training configuration settings. - data_conf : dict - Dictionary containing data configuration settings. - x_field : array-like - Input x coordinates after transformation (if applicable). - y_field : array-like - Input y coordinates after transformation (if applicable). - seds : array-like - Input spectral energy distributions. - trained_psf_model : keras.Model - Loaded PSF model used for prediction. - inferred_psfs : array-like or None - Array of inferred PSF images, populated after inference is performed. - simPSF : psf_models.simPSF - PSF simulator instance initialized with training model parameters. - data_handler : DataHandler - Data handler configured for inference, used to prepare inputs to the model. - n_bins_lambda : int - Number of spectral bins used for PSF simulation (loaded from config). - - Methods - ------- - load_inference_params() - Load parameters required for inference, including spectral binning. - get_trained_psf_model() - Load and return the trained Keras model for PSF inference. - run_inference() - Run the model on the input data and generate predicted PSFs. + Spectral energy distributions (SEDs). 
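+
+    Examples
+    --------
+    A minimal usage sketch; the config path and field values are hypothetical:
+
+    >>> psf_inf = PSFInference(
+    ...     "config/inference_conf.yaml", x_field=[0.1], y_field=[0.2], seds=seds
+    ... )
+    >>> psfs = psf_inf.run_inference()  # shape (n_sources, output_dim, output_dim)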
""" + def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None): - def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None): + self.inference_config_path = inference_config_path - self.inference_config_handler = InferenceConfigHandler( - inference_conf_path=inference_conf_path - ) - - self.inference_conf, self.training_conf, self.data_conf = ( - self.inference_config_handler.get_configs() - ) - - # Init source parameters + # Inputs for the model self.x_field = x_field self.y_field = y_field self.seds = seds - self.trained_psf_model = None - - # Init compute PSF placeholder - self.inferred_psfs = None - - # Load inference parameters - self.load_inference_params() - - # Instantiate the PSF simulator object - self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) - - # Instantiate the data handler - self.data_handler = DataHandler( - dataset_type="inference", - data_params=self.data_conf, - simPSF=self.simPSF, - n_bins_lambda=self.n_bins_lambda, - load_data=False, - dataset=None, + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """Prepare the configuration for inference.""" + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config ) - # Load the trained PSF model - self.trained_psf_model = self.get_trained_psf_model() + @property + def inference_config(self): + return self.config_handler.inference_config + + @property + def training_config(self): + return self.config_handler.training_config + + @property + def data_config(self): + return self.config_handler.data_config + + @property + def simPSF(self): + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.model_params) + return self._simPSF + + @property + def data_handler(self): + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=None, + ) + return self._data_handler - def load_inference_params(self): - """Load the inference parameters from the configuration file.""" - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + @property + def trained_psf_model(self): + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." 
+ def load_inference_model(self): + # Prepare the configuration for inference + self.prepare_configs() - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - - def get_trained_psf_model(self): - """Get the trained PSF model.""" - - # Load the trained PSF model - model_path = self.inference_config_handler.trained_model_path - model_dir_name = self.inference_config_handler - model_name = self.training_conf.training.model_params.model_name - id_name = self.training_conf.training.id_name + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name weights_path_pattern = os.path.join( model_path, - model_dir_name, - (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" ) + + # Load the trained PSF model return load_trained_psf_model( - self.training_conf, - self.data_conf, + self.training_config, + self.data_config, weights_path_pattern, ) - def set_source_parameters(self): - """Set the input source parameters for inferring the PSF. - - Note - ---- - The input source parameters are expected to be in the WaveDiff format. See the simulated data - format for more details. - - Parameters - ---------- - x_field : array-like - X coordinates of the sources in WaveDiff format. - y_field : array-like - Y coordinates of the sources in WaveDiff format. - seds : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. - It assumes the standard WaveDiff SED format. - - """ - # Positions array is of shape (n_sources, 2) - self.positions = tf.convert_to_tensor( + @property + def n_bins_lambda(self): + if self._n_bins_lambda is None: + self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + return self._n_bins_lambda + + @property + def batch_size(self): + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """Preprocess and return tensors for positions and SEDs.""" + positions = tf.convert_to_tensor( np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) - # Process SED data self.data_handler.process_sed_data(self.seds) - self.sed_data = self.data_handler.sed_data + sed_data = self.data_handler.sed_data + return positions, sed_data - def compute_psfs(self): - """Compute the PSFs for the input source parameters.""" + def run_inference(self): + """Run PSF inference and return the full PSF array.""" + positions, sed_data = self._prepare_positions_and_seds() - # Check if source parameters are set - if self.positions is None or self.sed_data is None: - raise ValueError( - "Source parameters not set. Call set_source_parameters first." 
-            )
+        self.engine = PSFInferenceEngine(
+            trained_model=self.trained_psf_model,
+            batch_size=self.batch_size,
+            output_dim=self.output_dim,
+        )
+        return self.engine.compute_psfs(positions, sed_data)
+
+    def _ensure_psf_inference_completed(self):
+        if self.engine is None or self.engine.inferred_psfs is None:
+            self.run_inference()
+
+    def get_psfs(self):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psfs()
+
+    def get_psf(self, index):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psf(index)
 
-        # Get the number of samples
+
-        n_samples = self.positions.shape[0]
+class PSFInferenceEngine:
+    def __init__(self, trained_model, batch_size: int, output_dim: int):
+        self.trained_model = trained_model
+        self.batch_size = batch_size
+        self.output_dim = output_dim
+        self._inferred_psfs = None
+
+    @property
+    def inferred_psfs(self) -> np.ndarray:
+        """Access the cached inferred PSFs, if available."""
+        return self._inferred_psfs
+
+    def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray:
+        """Compute and cache PSFs for the input source parameters."""
+        n_samples = positions.shape[0]
+        self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32)
 
         # Initialize counter
         counter = 0
-
-        # Initialize PSF array
-        self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim))
-
         while counter < n_samples:
             # Calculate the batch end element
-            if counter + self.batch_size <= n_samples:
-                end_sample = counter + self.batch_size
-            else:
-                end_sample = n_samples
+            end = min(counter + self.batch_size, n_samples)
 
             # Define the batch positions
-            batch_pos = self.positions[counter:end_sample, :]
-            batch_seds = self.sed_data[counter:end_sample, :, :]
-
-            # Generate PSFs for the current batch
+            batch_pos = positions[counter:end, :]
+            batch_seds = sed_data[counter:end, :, :]
             batch_inputs = [batch_pos, batch_seds]
-            batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False)
-
-            # Append to the PSF array
-            self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy()
+
+            # Generate PSFs for the current batch
+            batch_psfs = self.trained_model(batch_inputs, training=False)
+            self.inferred_psfs[counter:end, :, :] = batch_psfs.numpy()
 
             # Update the counter
-            counter += self.batch_size
+            counter = end
+
+        return self._inferred_psfs
 
     def get_psfs(self) -> np.ndarray:
-        """Get all the generated PSFs.
+        """Get all the generated PSFs."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs
+
+    def get_psf(self, index: int) -> np.ndarray:
+        """Get the PSF at a specific index."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs[index]
+
 
-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (n_samples, output_dim, output_dim).
-        """
-        if self.inferred_psfs is None:
-            self.compute_psfs()
-        return self.inferred_psfs
-
-    def get_psf(self, index) -> np.ndarray:
-        """Generate the generated PSF at a specific index.
-
-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (output_dim, output_dim).
- """ - if self.inferred_psfs is None: - self.compute_psfs() - return self.inferred_psfs[index] From d1a4fb5878e1942beeb715d6892b07194cfc85ee Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:26:26 +0200 Subject: [PATCH 052/129] Add unit tests for psf_inference --- .../test_inference/psf_inference_test.py | 255 ++++++++++++++++++ 1 file changed, 255 insertions(+) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index e69de29b..5709ed6d 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 +1,255 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. + +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine +) + +from wf_psf.utils.read_config import RecursiveNamespace + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32 + ) + ) + ) + return training_config + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path='/path/to/trained/model', + model_subdir='psf_model', + trained_model_config_path='config/training_config.yaml', + data_config_path=None + ), + model_params=RecursiveNamespace( + output_Q=1, + output_dim=64 + ) + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + 
"""Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params( + training_config, inference_config + ) + + # Assert that the model_params were overwritten correctly + assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" + assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" + + assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference('/dummy/path.yaml') + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called['load'] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + def overwrite_model_params(self, *args): pass + + monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + + inference.prepare_configs() + + assert 'load' in called # Confirm lazy load happened + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + ) + assert inference.batch_size == 4 + + +@patch.object(PSFInference, 'prepare_configs') +@patch('wf_psf.inference.psf_inference.load_trained_psf_model') +def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config): + + data_config = MagicMock() + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = data_config + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + 
f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" + ) + + # Assert calls to the mocked methods + mock_prepare_configs.assert_called_once() + mock_load_trained_psf_model.assert_called_once_with( + mock_config_handler.training_config, + mock_config_handler.data_config, + weights_path_pattern + ) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock_compute_psfs.call_count == 1 From 5ea8bcb5efaa49f91bb6564a6c2e19f927f94ff9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 12 Jun 2025 13:22:51 +0200 Subject: [PATCH 053/129] Bugfix: Ensure updated training_config.training.model_params are passed to simPSF Move call to prepare_configs() into run_inference() to ensure model_params are overwritten before preparing SEDs. 
--- src/wf_psf/inference/psf_inference.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 9012e3fc..01e7cf1a 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -147,7 +147,7 @@ def data_config(self): @property def simPSF(self): if self._simPSF is None: - self._simPSF = psf_models.simPSF(self.model_params) + self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF @property @@ -171,9 +171,7 @@ def trained_psf_model(self): return self._trained_psf_model def load_inference_model(self): - # Prepare the configuration for inference - self.prepare_configs() - + """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -228,6 +226,10 @@ def _prepare_positions_and_seds(self): def run_inference(self): """Run PSF inference and return the full PSF array.""" + # Prepare the configuration for inference + self.prepare_configs() + + # Prepare positions and SEDs for inference positions, sed_data = self._prepare_positions_and_seds() self.engine = PSFInferenceEngine( From c82836b7af8fbbf36230148bbb537436226717d7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 12 Jun 2025 13:25:28 +0200 Subject: [PATCH 054/129] test(simPSF): add unit test to verify updated model_params are passed Also moves prepare_configs() assertion from test load_inference() to unit test for run_inference(). --- .../test_inference/psf_inference_test.py | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 5709ed6d..4add460b 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -179,17 +179,15 @@ def test_batch_size_positive(): assert inference.batch_size == 4 -@patch.object(PSFInference, 'prepare_configs') @patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config): +def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config): - data_config = MagicMock() mock_config_handler = MagicMock(spec=InferenceConfigHandler) mock_config_handler.trained_model_path = "mock/path/to/model" mock_config_handler.training_config = mock_training_config mock_config_handler.inference_config = mock_inference_config mock_config_handler.model_subdir = "psf_model" - mock_config_handler.data_config = data_config + mock_config_handler.data_config = MagicMock() psf_inf = PSFInference("dummy_path.yaml") psf_inf._config_handler = mock_config_handler @@ -203,17 +201,16 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, ) # Assert calls to the mocked methods - mock_prepare_configs.assert_called_once() mock_load_trained_psf_model.assert_called_once_with( mock_config_handler.training_config, mock_config_handler.data_config, weights_path_pattern ) - +@patch.object(PSFInference, 'prepare_configs') @patch.object(PSFInference, '_prepare_positions_and_seds') @patch.object(PSFInferenceEngine, 'compute_psfs') -def test_run_inference(mock_compute_psfs, 
mock_prepare_positions_and_seds, psf_test_setup): +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -228,6 +225,42 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_t assert psfs.shape == expected_psfs.shape mock_prepare_positions_and_seds.assert_called_once() mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + +@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q @patch.object(PSFInference, '_prepare_positions_and_seds') From c99b1e8e0a45e2d4d76f7edbac484bdbf518be07 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 12 Jun 2025 16:34:20 +0100 Subject: [PATCH 055/129] Bug: replace self.data_conf with self.data_config --- src/wf_psf/inference/psf_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 01e7cf1a..6f798245 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -36,7 +36,7 @@ def load_configs(self): if self.data_config_path is not None: # Load the data configuration - self.data_conf = read_conf(self.data_config_path) + self.data_config = read_conf(self.data_config_path) def set_config_paths(self): From cf0e8e1050781f0e2afe3bd953290ef41a057933 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 19:59:51 +0200 Subject: [PATCH 056/129] Change logger.warnings to ValueErrors for missing fields in datasets & remove redundant check --- src/wf_psf/data/data_handler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 84af1eba..f341c28d 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -161,16 +161,16 @@ def _validate_dataset_structure(self): if "positions" not in self.dataset: raise ValueError("Dataset missing required field: 'positions'") - if 
self.dataset_type == "train": + if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - logger.warning("Missing 'noisy_stars' in 'train' dataset.") + raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") elif self.dataset_type == "test": if "stars" not in self.dataset: - logger.warning("Missing 'stars' in 'test' dataset.") + raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") elif self.dataset_type == "inference": pass else: - logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" @@ -179,12 +179,10 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"], dtype=tf.float32 ) if self.dataset_type == "training": - if "noisy_stars" in self.dataset: self.dataset["noisy_stars"] = tf.convert_to_tensor( self.dataset["noisy_stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "test": if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( From 814d047658463909019c588dc49e28a97fd38a20 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 20:00:21 +0200 Subject: [PATCH 057/129] Update unit tests with changes to data_handler.py --- .../tests/test_data/data_handler_test.py | 78 +++++-------------- 1 file changed, 19 insertions(+), 59 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 42e19a71..8d1b5421 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -16,22 +16,16 @@ def mock_sed(): return np.linspace(0.1, 1.0, 50) -def test_process_sed_data(data_params, simPSF): - # Test processing SED data without initialization - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=False - ) - assert data_handler.sed_data is None # SED data should not be processed - - def test_process_sed_data_auto_load(data_params, simPSF): # load_data=True → dataset is used and SEDs processed automatically data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=True + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda -def test_load_train_dataset(tmp_path, data_params, simPSF): +def test_load_train_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -48,9 +42,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler( @@ -68,7 +60,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_test_dataset(tmp_path, data_params, simPSF): +def test_load_test_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -85,14 +77,12 @@ def 
test_load_test_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 data_handler = DataHandler( dataset_type="test", - data_params=data_params.test, + data_params=data_params, simPSF=simPSF, n_bins_lambda=n_bins_lambda, load_data=False, @@ -120,21 +110,21 @@ def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler( "training", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." + ): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") + data_handler.validate_and_process_dataset() -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): +def test_load_test_dataset_missing_stars(tmp_path, simPSF): """Test that a warning is raised if 'stars' is missing in test data.""" data_dir = tmp_path / "data" data_dir.mkdir() @@ -147,48 +137,18 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 data_handler = DataHandler( "test", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." 
+ ): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - # Missing 'noisy_stars' - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - - data_handler = DataHandler( - dataset_type="train", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=10, - load_data=False, - ) - - data_handler.load_dataset() - data_handler.process_sed_data(mock_dataset["SEDs"]) - - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: - data_handler._validate_dataset_structure() - mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + data_handler.validate_and_process_dataset() def test_get_obs_positions(mock_data): From dfae6ced67740b2c88d17ea10011e0223a96c54a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 058/129] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/data_handler.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index f341c28d..6ec6ec56 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -163,10 +163,14 @@ def _validate_dataset_structure(self): if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) elif self.dataset_type == "test": if "stars" not in self.dataset: - raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." 
+ ) elif self.dataset_type == "inference": pass else: @@ -179,10 +183,10 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"], dtype=tf.float32 ) if self.dataset_type == "training": - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif self.dataset_type == "test": if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( From c49ebc7b3f6d80150617d800a999f035c6b880b2 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 059/129] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 6ec6ec56..7ec53b59 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -188,10 +188,9 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) def process_sed_data(self, sed_data): """ From 49039287defcdf4cfbeedc4b11498066e0a2325e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 060/129] Update unit tests associated to changes in data_handler.py --- src/wf_psf/data/data_handler.py | 42 +++++-------------- .../tests/test_data/data_handler_test.py | 15 +++++++ 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 7ec53b59..7e7b3881 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -188,43 +188,21 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) + if "stars" in self.dataset: + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") + elif "inference" == self.dataset_type: + pass def process_sed_data(self, sed_data): - """ - Generate and process SED (Spectral Energy Distribution) data. - - This method transforms raw SED inputs into TensorFlow tensors suitable for model input. - It generates wavelength-binned SED elements using the PSF simulator, converts the result - into a tensor, and transposes it to match the expected shape for training or inference. - - Parameters - ---------- - sed_data : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. + """Process SED Data. - Raises - ------ - ValueError - If `sed_data` is None. + A method to generate and process SED data. - Notes - ----- - The resulting tensor is stored in `self.sed_data` and has shape - `(num_samples, n_bins_lambda, n_components)`, where: - - `num_samples` is the number of SEDs, - - `n_bins_lambda` is the number of wavelength bins, - - `n_components` is the number of components per SED (e.g., filters or basis terms). 
- - The intermediate tensor is created with `tf.float64` for precision during generation, - but is converted to `tf.float32` after processing for use in training. """ - if sed_data is None: - raise ValueError("SED data must be provided explicitly or via dataset.") - self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 8d1b5421..6545964a 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -150,6 +150,21 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.load_dataset() data_handler.validate_and_process_dataset() + data_handler = DataHandler( + dataset_type="train", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=10, + load_data=False + ) + + data_handler.load_dataset() + data_handler.process_sed_data(mock_dataset["SEDs"]) + + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + data_handler._validate_dataset_structure() + mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + def test_get_obs_positions(mock_data): observed_positions = get_obs_positions(mock_data) From 152b0d83e700597a8b9b6d4d891063b212da59a8 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 061/129] automatic formatting --- src/wf_psf/data/data_handler.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 7e7b3881..b25881a2 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -198,11 +198,38 @@ def _convert_dataset_to_tensorflow(self): pass def process_sed_data(self, sed_data): - """Process SED Data. + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. - A method to generate and process SED data. + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
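+
+        For example (illustrative shapes only, assuming two raw SEDs and
+        ``n_bins_lambda=10``)::
+
+            data_handler.process_sed_data(seds)
+            data_handler.sed_data.shape  # (2, 10, n_components)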
""" + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 From de0892cd25a1e412e93aa0aaaf160a357ce70e87 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:16:15 +0200 Subject: [PATCH 062/129] Refactor: add ZernikeInputs dataclass, ZernikeInputsFactory, helper methods for assembling zernike contributions according to run_type mode: training, simulation, or inference --- src/wf_psf/data/data_zernike_utils.py | 207 +++++++++++++++++++------- 1 file changed, 152 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 760adb11..648b260c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -11,14 +11,89 @@ """ +from dataclasses import dataclass +from typing import Optional, Union import numpy as np import tensorflow as tf -from typing import Optional +from wf_psf.utils.read_config import RecursiveNamespace import logging logger = logging.getLogger(__name__) +@dataclass +class ZernikeInputs: + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + batch_size: int + + +class ZernikeInputsFactory: + @staticmethod + def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + """Builds a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. + run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset = None + positions = None + + if run_type in {"training", "simulation"}: + centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets + positions = np.concatenate( + [ + data.training_dataset["positions"], + data.test_dataset["positions"] + ], + axis=0, + ) + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Zernike prior explicitly provided; ignoring dataset-based prior despite use_prior=True." + ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + centroid_dataset = None + positions = data["positions"] + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + batch_size=model_params.batch_size, + ) + + def get_np_zernike_prior(data): """Get the zernike prior from the provided dataset. 
@@ -45,80 +120,102 @@ def get_np_zernike_prior(data): return zernike_prior - -def get_zernike_prior(model_params, data, batch_size: int=16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + combined = np.zeros((n_samples, max_order), dtype=np.float32) + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """ + Assemble the total Zernike contribution map by combining the prior, + centroid correction, and CCD misalignment correction. Parameters ---------- model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior (e.g., from PDC or another model). + centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. Returns ------- tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - + A tensor representing the full Zernike contribution map. 
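+
+    Examples
+    --------
+    A sketch assuming ``model_params.use_prior`` is True and the other
+    corrections are disabled; ``prior`` is a hypothetical array of shape
+    ``(n_samples, n_zernikes)``:
+
+    >>> zk_map = assemble_zernike_contributions(model_params, zernike_prior=prior)
+    >>> zk_map.shape  # TensorShape([n_samples, n_zernikes])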
""" - # List of zernike contribution + zernike_contribution_list = [] - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, np.ndarray): + zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) + tf.convert_to_tensor(centroid_correction, dtype=tf.float32) ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) + logger.info("Skipping centroid correction (not enabled or no dataset).") - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append( + tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) ) + else: + logger.info("Skipping CCD misalignment correction (not enabled or no positions).") - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. 
Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) - zernike_contribution += current_zernike_contribution + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): From 8db8de74094704e1fbc6ecc534b9062dda2c580f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:18:57 +0200 Subject: [PATCH 063/129] Update docstring describing data_conf types permitted --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 1d2e267f..797be8fc 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -25,8 +25,8 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): training_conf : RecursiveNamespace Configuration object containing model parameters and training hyperparameters. Supports attribute-style access to nested fields. - data_conf : RecursiveNamespace - Configuration object containing data-related parameters. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). weights_path_pattern : str Glob-style pattern used to locate the model weights file. From 9f7f19e20df01907aec98a8a84289a5ec00ed9f0 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:45:41 +0200 Subject: [PATCH 064/129] Move imports to method to avoid circular imports --- src/wf_psf/data/centroids.py | 2 +- src/wf_psf/instrument/ccd_misalignments.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 83ccf8c7..97cb069f 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -10,7 +10,6 @@ import scipy.signal as scisig from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff import tensorflow as tf from typing import Optional @@ -127,6 +126,7 @@ def compute_zernike_tip_tilt( - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. 
""" + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 35343886..76386f20 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, data): @@ -386,6 +386,8 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. """ + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff + dz = self.get_dz_from_position(pos) return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) From 3c84c1f4dcdd4912cc785d045b56fb7ccd80bae6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:47:35 +0200 Subject: [PATCH 065/129] Remove batch_size arg from ZernikeInputsFactory ; raise ValueError to check Zernike contributions have the same number of samples --- src/wf_psf/data/data_zernike_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 648b260c..b03ff400 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -15,6 +15,8 @@ from typing import Optional, Union import numpy as np import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -26,7 +28,6 @@ class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. 
from PDC) centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections - batch_size: int class ZernikeInputsFactory: @@ -89,8 +90,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions, - batch_size=model_params.batch_size, + misalignment_positions=positions ) @@ -133,6 +133,8 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order), dtype=np.float32) for contrib in contributions: From 362ed76396c78b9ce19c939a3d69ab1a8930894a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:49:32 +0200 Subject: [PATCH 066/129] Add and set run_type attribute to DataConfigHandler object in TrainingConfigHandler constructor --- src/wf_psf/utils/configs_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index a7e42ceb..f59fdc17 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -188,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) From c8f060ec44e7480f57acc1ee3549363156ed144b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:51:04 +0200 Subject: [PATCH 067/129] Add and set run_type attribute ; Replace var name end with end_sample in PSFInferenceEngine --- src/wf_psf/inference/psf_inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6f798245..5be49343 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -162,6 +162,7 @@ def data_handler(self): load_data=False, dataset=None, ) + self._data_handler.run_type = "inference" return self._data_handler @property @@ -272,7 +273,7 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: counter = 0 while counter < n_samples: # Calculate the batch end element - end = min(counter + self.batch_size, n_samples) + end_sample = min(counter + self.batch_size, n_samples) # Define the batch positions batch_pos = positions[counter:end_sample, :] @@ -281,10 +282,10 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) - self.inferred_psfs[counter:end, :, :] = batch_psfs.numpy() + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter - counter = end + counter = end_sample return self._inferred_psfs From 8458a9ad73bc0d862ac270755e8ba900712ad7c7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:52:26 +0200 Subject: [PATCH 068/129] Refactor TFPhysicalPolychromaticField to lazy load property objects and attributes dynamically at run-time according to the 
run_type: training or inference

---
 .../psf_model_physical_polychromatic.py       | 342 ++++++++----------
 1 file changed, 151 insertions(+), 191 deletions(-)

diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index baec8923..971dbb63 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -11,7 +11,7 @@
 import tensorflow as tf
 from tensorflow.python.keras.engine import data_adapter
 from wf_psf.data.data_handler import get_obs_positions
-from wf_psf.data.data_zernike_utils import get_zernike_prior
+from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions
 from wf_psf.psf_models import psf_models as psfm
 from wf_psf.psf_models.tf_modules.tf_layers import (
     TFPolynomialZernikeField,
@@ -98,8 +98,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None):
             A Recursive Namespace object containing parameters for this PSF model class.
         training_params: Recursive Namespace
             A Recursive Namespace object containing training hyperparameters for this PSF model class.
-        data: DataConfigHandler
-            A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients.
+        data: DataConfigHandler or dict
+            A DataConfigHandler object or dict that provides access to a single
+            dataset or multiple datasets (e.g. train and test), as well as prior
+            knowledge like Zernike coefficients.
         coeff_mat: Tensor or None, optional
             Coefficient matrix defining the parametric PSF field model.

@@ -109,204 +109,152 @@ def __init__(self, model_params, training_params, data, coeff_mat=None):
             Initialized instance of the TFPhysicalPolychromaticField class.
         """
         super().__init__(model_params, training_params, coeff_mat)
-        self._initialize_parameters_and_layers(
-            model_params, training_params, data, coeff_mat
-        )
+        self.model_params = model_params
+        self.training_params = training_params
+        self.data = data
+        self.run_type = data.run_type

-    def _initialize_parameters_and_layers(
-        self,
-        model_params: RecursiveNamespace,
-        training_params: RecursiveNamespace,
-        data: DataConfigHandler,
-        coeff_mat: Optional[tf.Tensor] = None,
-    ):
-        """Initialize Parameters of the PSF model.
-
-        This method sets up the PSF model parameters, observational positions,
-        Zernike coefficients, and components required for the automatically
-        differentiable optical forward model.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
-        training_params: Recursive Namespace
-            A Recursive Namespace object containing training hyperparameters for this PSF model class.
-        data: DataConfigHandler object
-            A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients.
-        coeff_mat: Tensor or None, optional
-            Initialization of the coefficient matrix defining the parametric psf field model.
-
-        Notes
-        -----
-        - Initializes Zernike parameters based on dataset priors.
-        - Configures the PSF model layers according to `model_params`.
-        - If `coeff_mat` is provided, the model coefficients are updated accordingly.
-        """
+        # Initialize the model parameters and layers
         self.output_Q = model_params.output_Q
-        self.obs_pos = get_obs_positions(data)
         self.l2_param = model_params.param_hparams.l2_param
-        # Inputs: Save optimiser history Parametric model features
-        self.save_optim_history_param = (
-            model_params.param_hparams.save_optim_history_param
-        )
-        # Inputs: Save optimiser history NonParameteric model features
-        self.save_optim_history_nonparam = (
-            model_params.nonparam_hparams.save_optim_history_nonparam
-        )
-        self._initialize_zernike_parameters(model_params, data)
-        self._initialize_layers(model_params, training_params)
+        self.output_dim = model_params.output_dim
+
+        # Initialize the external Zernike prior placeholder (resolved lazily at run-time)
+        self._external_prior = None

         # Initialize the model parameters with non-default value
         if coeff_mat is not None:
             self.assign_coeff_matrix(coeff_mat)

-    def _initialize_zernike_parameters(self, model_params, data):
-        """Initialize the Zernike parameters.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
-        data: DataConfigHandler object
-            A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients.
-        """
-        self.zks_prior = get_zernike_prior(model_params, data, data.batch_size)
-        self.n_zks_total = max(
-            model_params.param_hparams.n_zernikes,
-            tf.cast(tf.shape(self.zks_prior)[1], tf.int32),
+    def _assemble_zernike_contributions(self):
+        """Build the Zernike inputs for the current run type and assemble their total contribution."""
+        zks_inputs = ZernikeInputsFactory.build(
+            data=self.data,
+            run_type=self.run_type,
+            model_params=self.model_params,
+            prior=self._external_prior,
        )
+        return assemble_zernike_contributions(
+            model_params=self.model_params,
+            zernike_prior=zks_inputs.zernike_prior,
+            centroid_dataset=zks_inputs.centroid_dataset,
+            positions=zks_inputs.misalignment_positions,
+            batch_size=self.training_params.batch_size,
        )

-    def _initialize_layers(self, model_params, training_params):
-        """Initialize the layers of the PSF model.
-
-        This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
-        training_params: Recursive Namespace
-            A Recursive Namespace object containing training hyperparameters for this PSF model class.
-        coeff_mat: Tensor or None, optional
-            Initialization of the coefficient matrix defining the parametric psf field model.
-
-        """
-        self._initialize_physical_layer(model_params)
-        self._initialize_polynomial_Z_field(model_params)
-        self._initialize_Zernike_OPD(model_params)
-        self._initialize_batch_polychromatic_layer(model_params, training_params)
-        self._initialize_nonparametric_opd_layer(model_params, training_params)
-
-    def _initialize_physical_layer(self, model_params):
-        """Initialize the physical layer of the PSF model.
-
-        This method initializes the physical layer of the PSF model using parameters
-        specified in the `model_params` object.
-
-        Parameters
-        ----------
-        model_params: Recursive Namespace
-            A Recursive Namespace object containing parameters for this PSF model class.
- """ - self.tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_prior, - interpolation_type=model_params.interpolation_type, - interpolation_args=model_params.interpolation_args, + @property + def save_param_history(self) -> bool: + """Check if the model should save the optimization history for parametric features.""" + return getattr(self.model_params.param_hparams, "save_optim_history_param", False) + + @property + def save_nonparam_history(self) -> bool: + """Check if the model should save the optimization history for non-parametric features.""" + return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) + + + # === Lazy properties ===. + @property + def obs_pos(self): + """Lazy loading of the observation positions.""" + if not hasattr(self, "_obs_pos"): + if self.run_type == "training" or self.run_type == "simulation": + # Get the observation positions from the data handler + self._obs_pos = get_obs_positions(self.data) + elif self.run_type == "inference": + # For inference, we might not have a data handler, so we use the model parameters + self._obs_pos = self.data.dataset["positions"] + return self._obs_pos + + @property + def zks_total_contribution(self): + """Lazily load all Zernike contributions, including prior and corrections.""" + if not hasattr(self, "_zks_total_contribution"): + self._zks_total_contribution = self._assemble_zernike_contributions() + return self._zks_total_contribution + + @property + def n_zks_total(self): + """Get the total number of Zernike coefficients.""" + if not hasattr(self, "_n_zks_total"): + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32), ) - - def _initialize_polynomial_Z_field(self, model_params): - """Initialize the polynomial Zernike field of the PSF model. - - This method initializes the polynomial Zernike field of the PSF model using - parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + return self._n_zks_total + + @property + def zernike_maps(self): + """Lazy loading of the Zernike maps.""" + if not hasattr(self, "_zernike_maps"): + self._zernike_maps = psfm.generate_zernike_maps_3d( + self.n_zks_total, self.model_params.pupil_diameter ) + return self._zernike_maps + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. 
- - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) - - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. - - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - - - """ - self.batch_size = training_params.batch_size - self.obscurations = psfm.tf_obscurations( - pupil_diam=model_params.pupil_diameter, - N_filter=model_params.LP_filter_length, - rotation_angle=model_params.obscuration_rotation_angle, - ) - self.output_dim = model_params.output_dim + @tf_poly_Z_field.deleter + def tf_poly_Z_field(self): + del self._tf_poly_Z_field - self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, + @property + def tf_physical_layer(self): + """Lazy loading of the physical layer of the PSF model.""" + if not hasattr(self, "_tf_physical_layer"): + self._tf_physical_layer = TFPhysicalLayer( + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, ) + + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + @property + def tf_batch_poly_PSF(self): + """Lazily initialize the batch polychromatic PSF layer.""" + if not hasattr(self, "_tf_batch_poly_PSF"): + obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) - def _initialize_nonparametric_opd_layer(self, model_params, training_params): - """Initialize the non-parametric OPD layer. - - This method initializes the non-parametric OPD layer using the provided - `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. 
- - """ - # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() - - self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - d_max=model_params.nonparam_hparams.d_max_nonparam, - opd_dim=tf.shape(self.zernike_maps)[1].numpy(), - ) + self._tf_batch_poly_PSF = TFBatchPolychromaticPSF( + obscurations=obscurations, + output_Q=self.output_Q, + output_dim=self.output_dim, + ) + return self._tf_batch_poly_PSF + + @property + def tf_np_poly_opd(self): + """Lazy loading of the non-parametric polynomial variations OPD layer.""" + if not hasattr(self, "_tf_np_poly_opd"): + self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=tf.shape(self.zernike_maps)[1].numpy(), + ) + return self._tf_np_poly_opd def get_coeff_matrix(self): """Get coefficient matrix.""" @@ -331,23 +278,21 @@ def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None: """ self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) + def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None: """Set the output sampling rate (output_Q) for PSF generation. This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. 
@@ -359,6 +304,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -472,12 +418,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -520,10 +470,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -548,10 +501,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -589,6 +545,7 @@ def compute_zernikes(self, input_positions): padded_zernike_params, padded_zernike_prior = self.pad_zernikes( zernike_params, zernike_prior ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -685,7 +642,7 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) From 61fb8dc9035036b2d53282807cbd908d48b02d0a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:53:46 +0200 Subject: [PATCH 069/129] Update/Add unit tests to test refactoring changes to psf_model_physical_polychromatic.py and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 296 ++++++++++++++++-- .../psf_model_physical_polychromatic_test.py | 202 ++---------- 2 files changed, 310 insertions(+), 188 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 692624be..90994dde 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,40 +1,296 @@ import pytest import numpy as np +from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - get_zernike_prior, + ZernikeInputs, + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, compute_zernike_tip_tilt, ) from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from types import SimpleNamespace as RecursiveNamespace -def test_get_zernike_prior(model_params, mock_data): - 
zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + +@pytest.fixture +def dummy_positions(): + return np.random.rand(4, 2).astype(np.float32) + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + +def test_training_without_prior(mock_model_params): + mock_model_params.use_prior = False + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((3, 2))} + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is data + assert zinputs.zernike_prior is None + np.testing.assert_array_equal( + zinputs.misalignment_positions, + np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) + ) + +@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") +def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((2, 2))} + mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] + mock_get_prior.assert_called_once_with(data) + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + + assert "Zernike prior explicitly provided" in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + mock_model_params.use_prior = True + data = { + "positions": np.ones((5, 2)), + "zernike_prior": np.array([42.0, 0.0]) + } + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + assert zinputs.centroid_dataset is None + assert (zinputs.zernike_prior == data["zernike_prior"]).all() + np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = 
tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) ) - assert np.array_equal(zernike_priors, expected_zernike_priors) + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + result = get_np_zernike_prior(data) -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array([ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ]) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([ + [1 + 5, 2 + 0], + [3 + 6, 4 + 0] + ]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) 
+ expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, + param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), dtype=tf.float32) + + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=tensor_prior + ) + + assert isinstance(result, tf.Tensor) + np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): + mock_model_params.add_ccd_misalignments = False + mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 + + with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + 
assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=None + ) def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 7e967465..13d78667 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,6 +9,7 @@ import pytest import numpy as np import tensorflow as tf +from unittest.mock import PropertyMock from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -31,6 +32,7 @@ def zks_prior(): def mock_data(mocker): mock_instance = mocker.Mock(spec=DataConfigHandler) # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" mock_instance.training_data = mocker.Mock() mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} mock_instance.test_data = mocker.Mock() @@ -46,145 +48,11 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock - -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.data_handler.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data - ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor 
- ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): # Create training params mock object mock_training_params = mocker.Mock() - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - # Create TFPhysicalPolychromaticField instance psf_field_instance = TFPhysicalPolychromaticField( mock_model_params, mock_training_params, mock_data @@ -202,8 +70,8 @@ def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + physical_layer_instance._n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) # Call the method under test padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( @@ -224,7 +92,7 @@ def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), 
tf.shape(zk_prior)[1].numpy() ) @@ -247,7 +115,7 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) @@ -262,43 +130,41 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity + # Expected output of mock components + padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) + # Patch tf_poly_Z_field property + mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + mocker.patch.object( + TFPhysicalPolychromaticField, + 'tf_poly_Z_field', + new_callable=PropertyMock, + return_value=mock_tf_poly_Z_field + ) + # Patch tf_physical_layer property + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, + 'tf_physical_layer', + new_callable=PropertyMock, + return_value=mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) + + # Patch pad_zernikes instance method directly (this one isn't a property) mocker.patch.object( physical_layer_instance, - "pad_zernikes", - return_value=(padded_zernike_param, padded_zernike_prior), + 'pad_zernikes', + return_value=(padded_zernike_param, padded_zernike_prior) ) - # Call the method under test - zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal - assert zernike_coeffs.shape == expected_values.shape + # Run the test + zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Assert that the tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) + assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file From 60dfe56fa479320d36313271edcab524373b911b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 22 Jun 2025 00:08:51 +0200 Subject: [PATCH 070/129] Replace arg: data in compute_ccd_misalignment with positions --- src/wf_psf/instrument/ccd_misalignments.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py 
b/src/wf_psf/instrument/ccd_misalignments.py index 76386f20..f739c4fc 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -13,22 +13,23 @@ from wf_psf.data.data_handler import get_np_obs_positions -def compute_ccd_misalignment(model_params, data): +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: """Compute CCD misalignment. Parameters ---------- model_params : RecursiveNamespace Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. Returns ------- zernike_ccd_misalignment_array : np.ndarray Numpy array containing the Zernike contributions to model the CCD chip misalignments. """ - obs_positions = get_np_obs_positions(data) + obs_positions = positions ccd_misalignment_calculator = CCDMisalignmentCalculator( tiles_path=model_params.ccd_misalignments_input_path, From 29f9fc59ebd169346b9891f6c666211965a38857 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:15:06 +0200 Subject: [PATCH 071/129] Correct object attributes for DataConfigHandler in ZernikeInputsFactory --- src/wf_psf/data/data_zernike_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index b03ff400..e50381c3 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_dataset["positions"], - data.test_dataset["positions"] + data.training_data.dataset["positions"], + data.test_data.dataset["positions"] ], axis=0, ) From a943dbe2d35c677d682c2e089436c1d12e222bd7 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:29:06 +0200 Subject: [PATCH 072/129] Add missing return for tf_physical_layer property --- src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 971dbb63..7f19171e 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -218,6 +218,7 @@ def tf_physical_layer(self): interpolation_type=self.model_params.interpolation_type, interpolation_args=self.model_params.interpolation_args, ) + return self._tf_physical_layer @property def tf_zernike_OPD(self): From e15444f603653a6d7386cfd38f3ff89579fefa60 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:02:17 +0200 Subject: [PATCH 073/129] Add tf_utils.py module to tf_modules subpackage --- src/wf_psf/psf_models/tf_modules/tf_utils.py | 45 ++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 src/wf_psf/psf_models/tf_modules/tf_utils.py diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..a4795f89 --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,45 @@ 
+"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf +import numpy as np + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). + + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) + From 00b07f768bc0a08619d9cbbc8dc9fe40dda90e9d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:03:50 +0200 Subject: [PATCH 074/129] Use ensure_tensor method from tf_utils.py to check/convert to tensorflow type; Remove get_obs_positions and replace with ensure_tensor method; add property tf_positions --- src/wf_psf/data/data_handler.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index b25881a2..2419b730 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -15,6 +15,7 @@ import os import numpy as np import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf from fractions import Fraction from typing import Optional, Union @@ -137,6 +138,10 @@ def __init__( self.dataset = None self.sed_data = None + @property + def tf_positions(self): + return ensure_tensor(self.dataset["positions"]) + def load_dataset(self): """Load dataset. @@ -236,7 +241,8 @@ def process_sed_data(self, sed_data): ) for _sed in sed_data ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) @@ -272,24 +278,6 @@ def get_np_obs_positions(data): return obs_positions -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. 
From 3446280a3938019a0aa30b363210febd2ee7a751 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:08:05 +0200 Subject: [PATCH 075/129] Refactor: Add eager-mode helpers and avoid lazy-loading obscurations in graph mode --- .../psf_model_physical_polychromatic.py | 76 +++++++++++-------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 7f19171e..0c5d3d06 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,7 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -21,6 +21,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.configs_handler import DataConfigHandler import logging @@ -112,7 +113,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.model_params = model_params self.training_params = training_params self.data = data - self.run_type = data.run_type + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() # Initialize the model parameters and layers self.output_Q = model_params.output_Q @@ -126,6 +128,22 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Eagerly initialise tf_batch_poly_PSF + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + + + def _get_run_type(self, data): + if hasattr(data, 'run_type'): + run_type = data.run_type + elif isinstance(data, dict) and 'run_type' in data: + run_type = data['run_type'] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + def _assemble_zernike_contributions(self): zks_inputs = ZernikeInputsFactory.build( data=self.data, @@ -150,21 +168,20 @@ def save_param_history(self) -> bool: def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) - - # === Lazy properties ===. 
- @property - def obs_pos(self): - """Lazy loading of the observation positions.""" - if not hasattr(self, "_obs_pos"): - if self.run_type == "training" or self.run_type == "simulation": - # Get the observation positions from the data handler - self._obs_pos = get_obs_positions(self.data) - elif self.run_type == "inference": - # For inference, we might not have a data handler, so we use the model parameters - self._obs_pos = self.data.dataset["positions"] - return self._obs_pos + def get_obs_pos(self): + assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: {self.run_type}" + + if self.run_type in {"training", "simulation"}: + raw_pos = get_np_obs_positions(self.data) + else: + raw_pos = self.data.dataset["positions"] + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. @property def zks_total_contribution(self): """Lazily load all Zernike contributions, including prior and corrections.""" @@ -227,22 +244,20 @@ def tf_zernike_OPD(self): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - @property - def tf_batch_poly_PSF(self): - """Lazily initialize the batch polychromatic PSF layer.""" - if not hasattr(self, "_tf_batch_poly_PSF"): - obscurations = psfm.tf_obscurations( + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + + obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, ) - self._tf_batch_poly_PSF = TFBatchPolychromaticPSF( + return TFBatchPolychromaticPSF( obscurations=obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - return self._tf_batch_poly_PSF @property def tf_np_poly_opd(self): @@ -646,23 +661,24 @@ def call(self, inputs, training=True): if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - - # Propagate to obtain the OPD + + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD + # Add L2 regularization loss on parametric OPD maps self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) + self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) ) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Combine both contributions + opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions From 278479f712a1202be077207265fbfcb8df9b1318 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:09:01 +0200 Subject: [PATCH 076/129] Replace deprecated get_obs_positions with get_np_obs_positions and apply ensure_tensor to convert obs_pos to tensorflow float32 --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 23a19b8b..b2d1b53e 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,8 +16,9 @@ TFPhysicalLayer, ) from 
wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From 0eacd98f495a7a8ab1756472d20668a4dcacc322 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:06:04 +0200 Subject: [PATCH 077/129] Remove tf.convert_to_tensor from all Zernike list contributors --- src/wf_psf/data/data_zernike_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index e50381c3..309732d8 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -53,7 +53,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = None positions = None - if run_type in {"training", "simulation"}: + if run_type in {"training", "simulation", "metrics"}: centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ @@ -178,9 +178,9 @@ def assemble_zernike_contributions( # Prior if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") - if isinstance(zernike_prior, np.ndarray): - zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) - zernike_contribution_list.append(zernike_prior) + if isinstance(zernike_prior, tf.Tensor): + zernike_prior = zernike_prior.numpy() + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -190,9 +190,7 @@ def assemble_zernike_contributions( centroid_correction = compute_centroid_correction( model_params, centroid_dataset, batch_size=batch_size ) - zernike_contribution_list.append( - tf.convert_to_tensor(centroid_correction, dtype=tf.float32) - ) + zernike_contribution_list.append(centroid_correction) else: logger.info("Skipping centroid correction (not enabled or no dataset).") @@ -200,9 +198,7 @@ def assemble_zernike_contributions( if model_params.add_ccd_misalignments and positions is not None: logger.info("Computing CCD misalignment correction...") ccd_misalignment = compute_ccd_misalignment(model_params, positions) - zernike_contribution_list.append( - tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) - ) + zernike_contribution_list.append(ccd_misalignment) else: logger.info("Skipping CCD misalignment correction (not enabled or no positions).") From 98fde06e4d559750fa7466cc51969d9b72809030 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:07:20 +0200 Subject: [PATCH 078/129] Add and set self.data_conf.run_type value to 'metrics' in MetricsConfigHandler --- src/wf_psf/utils/configs_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index f59fdc17..13ceb6de 100644 --- 
a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -134,7 +134,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr exit() self.simPSF = psf_models.simPSF(training_model_params) - + # Extract sub-configs early train_params = self.data_conf.data.training test_params = self.data_conf.data.test @@ -153,7 +153,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) - + self.batch_size = batch_size @@ -262,6 +262,7 @@ def __init__(self, metrics_conf, file_handler, training_conf=None): self._file_handler = file_handler self.training_conf = training_conf self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" self.metrics_dir = self._file_handler.get_metrics_dir( self._file_handler._run_output_dir ) From 5c3b585e13929e7dd9b62d890329ca07788ea360 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:19:32 +0200 Subject: [PATCH 079/129] Eagerly precompute Zernike components; add support for 'metrics' run_type - Precomputed `zks_total_contribution` outside the TensorFlow graph and converted it to tf.float32. - Calculated `n_zks_total` from the contribution shape and param config. - Eagerly generated Zernike maps and stored as tf.float32. - Derived OPD dimension (`opd_dim`) from Zernike map shape. - Generated obscurations via `tf_obscurations` and stored as tf.complex64. - Added 'metrics' as a valid `run_type` alongside 'training', 'simulation', and 'inference'. - Adjusted `get_obs_pos()` logic to treat 'metrics' like 'training' and 'simulation' for position loading. These changes avoid runtime `.numpy()` calls inside `@tf.function` contexts, improve robustness across run modes, and ensure compatibility with training and evaluation pipelines. 
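As an illustrative sketch of the pattern (standalone; not part of the diff below):

    import tensorflow as tf

    class EagerInit:
        def __init__(self):
            # Evaluated once, eagerly, at construction time: tf.shape() on an
            # eager tensor has a concrete value, so int(...) is safe here.
            self.opd_dim = int(tf.shape(tf.zeros((3, 8)))[1])

        @tf.function
        def scale(self, x):
            # opd_dim is a plain Python int captured as a graph constant, so
            # no .numpy() call (which would fail on symbolic tensors) is needed.
            return x * self.opd_dim

    print(EagerInit().scale(tf.constant(2.0)))  # tf.Tensor(16.0, shape=(), dtype=float32)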
--- .../psf_model_physical_polychromatic.py | 68 ++++++++++++------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 0c5d3d06..08c5ce71 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -128,6 +128,32 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) + + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1] + ) + + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, + pupil_diam=self.model_params.pupil_diameter + ) + + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] + + # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) + # Eagerly initialise tf_batch_poly_PSF self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() @@ -140,7 +166,7 @@ def _get_run_type(self, data): else: raise ValueError("data must have a 'run_type' attribute or key") - if run_type not in {"training", "simulation", "inference"}: + if run_type not in {"training", "simulation", "metrics", "inference"}: raise ValueError(f"Unknown run_type: {run_type}") return run_type @@ -170,9 +196,9 @@ def save_nonparam_history(self) -> bool: return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) def get_obs_pos(self): - assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: {self.run_type}" + assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - if self.run_type in {"training", "simulation"}: + if self.run_type in {"training", "simulation", "metrics"}: raw_pos = get_np_obs_positions(self.data) else: raw_pos = self.data.dataset["positions"] @@ -184,30 +210,26 @@ def get_obs_pos(self): # === Lazy properties ===. 
     @property
     def zks_total_contribution(self):
-        """Lazily load all Zernike contributions, including prior and corrections."""
-        if not hasattr(self, "_zks_total_contribution"):
-            self._zks_total_contribution = self._assemble_zernike_contributions()
         return self._zks_total_contribution
-
+
     @property
     def n_zks_total(self):
         """Get the total number of Zernike coefficients."""
-        if not hasattr(self, "_n_zks_total"):
-            self._n_zks_total = max(
-                self.model_params.param_hparams.n_zernikes,
-                tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32),
-            )
         return self._n_zks_total
 
     @property
     def zernike_maps(self):
-        """Lazy loading of the Zernike maps."""
-        if not hasattr(self, "_zernike_maps"):
-            self._zernike_maps = psfm.generate_zernike_maps_3d(
-                self.n_zks_total, self.model_params.pupil_diameter
-            )
+        """Get Zernike maps."""
         return self._zernike_maps
-
+
+    @property
+    def opd_dim(self):
+        return self._opd_dim
+
+    @property
+    def obscurations(self):
+        return self._obscurations
+
     @property
     def tf_poly_Z_field(self):
         """Lazy loading of the polynomial Zernike field layer."""
@@ -246,15 +268,9 @@ def tf_zernike_OPD(self):
 
     def _build_tf_batch_poly_PSF(self):
         """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations."""
-
-        obscurations = psfm.tf_obscurations(
-            pupil_diam=self.model_params.pupil_diameter,
-            N_filter=self.model_params.LP_filter_length,
-            rotation_angle=self.model_params.obscuration_rotation_angle,
-        )
 
         return TFBatchPolychromaticPSF(
-            obscurations=obscurations,
+            obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
         )
@@ -267,7 +283,7 @@ def tf_np_poly_opd(self):
                 x_lims=self.model_params.x_lims,
                 y_lims=self.model_params.y_lims,
                 d_max=self.model_params.nonparam_hparams.d_max_nonparam,
-                opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
+                opd_dim=self.opd_dim,
             )
         return self._tf_np_poly_opd

From e3aea260dd09c43fe319ff08e8b0a80c00f89c4f Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 25 Jun 2025 11:13:33 +0200
Subject: [PATCH 080/129] Correct value error: train in dataset_type with
 training

---
 src/wf_psf/data/data_handler.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index 2419b730..40611c71 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -184,17 +184,13 @@ def _validate_dataset_structure(self):
 
     def _convert_dataset_to_tensorflow(self):
         """Convert dataset to TensorFlow tensors."""
-        self.dataset["positions"] = tf.convert_to_tensor(
-            self.dataset["positions"], dtype=tf.float32
-        )
-        if self.dataset_type == "training":
-            self.dataset["noisy_stars"] = tf.convert_to_tensor(
-                self.dataset["noisy_stars"], dtype=tf.float32
-            )
-
+        self.dataset["positions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32)
+
+        if self.dataset_type == "train":
+            self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32)
         elif self.dataset_type == "test":
             if "stars" in self.dataset:
-                self.dataset["stars"] = tf.convert_to_tensor(
+                self.dataset["stars"] = ensure_tensor(
                     self.dataset["stars"], dtype=tf.float32
                 )
             else:

From ee0f8e03d51da98db8c9da7aec1cb19cf579c954 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 25 Jun 2025 22:17:36 +0200
Subject: [PATCH 081/129] fix: pass random seed to
 TFNonParametricPolynomialVariationsOPD constructor in
 psf_model_physical_polychromatic.py

---
 src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index 08c5ce71..9979bcf1 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -282,6 +282,7 @@ def tf_np_poly_opd(self):
         self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD(
             x_lims=self.model_params.x_lims,
             y_lims=self.model_params.y_lims,
+            random_seed=self.model_params.param_hparams.random_seed,
             d_max=self.model_params.nonparam_hparams.d_max_nonparam,
             opd_dim=self.opd_dim,
         )

From 776c019799f46789958f2f791fbb01ecf4533bed Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Thu, 26 Jun 2025 14:09:58 +0200
Subject: [PATCH 082/129] Refactor to suppress TensorFlow debug msgs: replace
 lambda in call method with a proper function (find_position_indices); enable
 batch processing for better graph optimization

---
 src/wf_psf/psf_models/tf_modules/tf_layers.py | 13 ++---
 src/wf_psf/psf_models/tf_modules/tf_utils.py  | 55 +++++++++++++++++++
 2 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py
index fcb5e8f9..5bb5e2df 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_layers.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py
@@ -9,6 +9,7 @@
 import tensorflow as tf
 import tensorflow_addons as tfa
 from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF
+from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices
 from wf_psf.utils.utils import calc_poly_position_mat
 import wf_psf.utils.utils as utils
 from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf
@@ -16,7 +17,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 class TFPolynomialZernikeField(tf.keras.layers.Layer):
     """Calculate the zernike coefficients for a given position.
 
@@ -964,6 +964,7 @@ def interpolate_independent_Zk(self, positions):
 
         return interp_zks[:, :, tf.newaxis, tf.newaxis]
 
+
     def call(self, positions):
         """Calculate the prior Zernike coefficients for a batch of positions.
 
@@ -999,12 +1000,10 @@ def call(self, positions):
 
         """
 
-        def calc_index(idx_pos):
-            return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0]
+        # Find indices for all positions in one batch operation
+        idx = find_position_indices(self.obs_pos, positions)
 
-        # Calculate the indices of the input batch
-        indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64)
-        # Recover the prior zernikes from the batch indexes
-        batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0)
+        # Gather the corresponding Zernike coefficients
+        batch_zks = tf.gather(self.zks_prior, idx, axis=0)
 
         return batch_zks[:, :, tf.newaxis, tf.newaxis]

diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py
index a4795f89..d0a2002c 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_utils.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py
@@ -18,6 +18,61 @@
 import tensorflow as tf
 import numpy as np
 
+
+@tf.function
+def find_position_indices(obs_pos, batch_positions):
+    """Find indices of batch positions within observed positions using vectorized operations.
+
+    This function locates the indices of multiple query positions within a
+    reference set of observed positions using broadcasting and vectorized operations.
+    Each position in the batch must have an exact match in the observed positions.
+ + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos" + ) + + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. From ab8856e8c3d58f9952c5e41e3f05c7eaffbebffc Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 26 Jun 2025 17:24:28 +0200 Subject: [PATCH 083/129] Match old behaviour with conditional and float64 accumulation --- src/wf_psf/data/data_zernike_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 309732d8..6ef2c93c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -131,12 +131,17 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray if not contributions: raise ValueError("No contributions provided.") + if len(contributions) == 1: + return contributions[0] + max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): raise ValueError("All contributions must have the same number of samples.") - combined = np.zeros((n_samples, max_order), dtype=np.float32) + combined = np.zeros((n_samples, max_order)) + for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded From ea854776db88f36b9060bb0dceeb716038b9ba05 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 8 Jul 2025 18:26:02 +0200 Subject: [PATCH 084/129] Add helper to stack x/y field coordinates into (N, 2) positions array - Introduced get_positions() to convert x_field and y_field into a stacked (N, 2) array. - Updated data handler to pass positions and sed_data explicitly. - Includes validation for shape mismatches and None inputs. 
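For illustration, the stacking behaviour as a standalone sketch (it mirrors the helper added below, but outside the class):

    import numpy as np

    def get_positions(x_field, y_field):
        # Stack x/y field coordinates into an (N, 2) array of [x, y] pairs.
        if x_field is None or y_field is None:
            return None
        x_arr = np.asarray(x_field)
        y_arr = np.asarray(y_field)
        if x_arr.size != y_arr.size:
            raise ValueError(
                f"x_field and y_field must have the same length. "
                f"Got {x_arr.size} and {y_arr.size}"
            )
        return np.column_stack((x_arr.flatten(), y_arr.flatten()))

    print(get_positions([1.0, 2.0], [10.0, 20.0]))
    # [[ 1. 10.]
    #  [ 2. 20.]]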
--- src/wf_psf/inference/psf_inference.py | 36 +++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5be49343..f3e876c3 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -160,7 +160,8 @@ def data_handler(self): simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, - dataset=None, + dataset={"positions": self.get_positions()}, + sed_data = self.seds, ) self._data_handler.run_type = "inference" return self._data_handler @@ -171,6 +172,37 @@ def trained_psf_model(self): self._trained_psf_model = self.load_inference_model() return self._trained_psf_model + def get_positions(self): + """ + Combine x_field and y_field into position pairs. + + Returns + ------- + numpy.ndarray + Array of shape (num_positions, 2) where each row contains [x, y] coordinates. + Returns None if either x_field or y_field is None. + + Raises + ------ + ValueError + If x_field and y_field have different lengths. + """ + if self.x_field is None or self.y_field is None: + return None + + x_arr = np.asarray(self.x_field) + y_arr = np.asarray(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError(f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}") + + # Flatten arrays to handle any input shape, then stack + x_flat = x_arr.flatten() + y_flat = y_arr.flatten() + + return np.column_stack((x_flat, y_flat)) + def load_inference_model(self): """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path @@ -187,7 +219,7 @@ def load_inference_model(self): # Load the trained PSF model return load_trained_psf_model( self.training_config, - self.data_config, + self.data_handler, weights_path_pattern, ) From 7b24c45a3b13a4eb89c5da65d0302cd92c10d07d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 10 Jul 2025 10:35:39 +0200 Subject: [PATCH 085/129] Add helper method to prepare dataset for inference & handle empty/None fields --- src/wf_psf/inference/psf_inference.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index f3e876c3..6bba2282 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -150,9 +150,17 @@ def simPSF(self): self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF + + def _prepare_dataset_for_inference(self): + """Prepare dataset dictionary for inference, returning None if positions are invalid.""" + positions = self.get_positions() + if positions is None: + return None + return {"positions": positions} + @property def data_handler(self): - if self._data_handler is None: + if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( dataset_type="inference", @@ -160,7 +168,7 @@ def data_handler(self): simPSF=self.simPSF, n_bins_lambda=self.n_bins_lambda, load_data=False, - dataset={"positions": self.get_positions()}, + dataset=self._prepare_dataset_for_inference(), sed_data = self.seds, ) self._data_handler.run_type = "inference" @@ -193,6 +201,9 @@ def get_positions(self): x_arr = np.asarray(self.x_field) y_arr = np.asarray(self.y_field) + if x_arr.size == 0 or y_arr.size == 0: + return None + if x_arr.size != y_arr.size: raise ValueError(f"x_field and y_field must have the same length. 
" f"Got {x_arr.size} and {y_arr.size}") From aeeafa1bfea1a22bbe4360e9e76d36c19f78fe34 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 22 Jul 2025 10:09:13 +0200 Subject: [PATCH 086/129] Update data_handler_test replacing "get_obs_positions" (deprecation) with "get_np_obs_positions" --- src/wf_psf/tests/test_data/data_handler_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 6545964a..fdf5a436 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -3,7 +3,7 @@ import tensorflow as tf from wf_psf.data.data_handler import ( DataHandler, - get_obs_positions, + get_np_obs_positions, extract_star_data, ) from wf_psf.utils.read_config import RecursiveNamespace @@ -166,8 +166,8 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") -def test_get_obs_positions(mock_data): - observed_positions = get_obs_positions(mock_data) +def test_get_np_obs_positions(mock_data): + observed_positions = get_np_obs_positions(mock_data) expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) From 66b6e6b3a164b14e4c60625273eb76dca0116545 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 6 Aug 2025 17:58:48 +0200 Subject: [PATCH 087/129] Remove deprecated code from rebase --- src/wf_psf/tests/test_data/data_handler_test.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index fdf5a436..c1310d21 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -150,21 +150,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.load_dataset() data_handler.validate_and_process_dataset() - data_handler = DataHandler( - dataset_type="train", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=10, - load_data=False - ) - - data_handler.load_dataset() - data_handler.process_sed_data(mock_dataset["SEDs"]) - - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: - data_handler._validate_dataset_structure() - mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") - def test_get_np_obs_positions(mock_data): observed_positions = get_np_obs_positions(mock_data) From 29970095a01672fcfa92cf8dbe7451a723aadd2a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 6 Aug 2025 17:59:32 +0200 Subject: [PATCH 088/129] Remove duplicated checks on arg existance --- src/wf_psf/data/data_handler.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 40611c71..120974eb 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -189,14 +189,7 @@ def _convert_dataset_to_tensorflow(self): if self.dataset_type == "train": self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32) elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = ensure_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - elif "inference" == self.dataset_type: - pass + self.dataset["stars"] = 
ensure_tensor(self.dataset["stars"], dtype=tf.float32) def process_sed_data(self, sed_data): """ From f4adcd453a5ef19981eb1f0f9afd459e0fe13025 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:16:57 +0200 Subject: [PATCH 089/129] Improve Zernike prior handling in assemble_zernike_contributions - Support both NumPy arrays and TensorFlow tensors as valid inputs for zernike_prior. - Added an eager execution check before calling `.numpy()` to ensure safe conversion of tensors. - Raise informative errors for unsupported types or if eager mode is disabled. - Updated function docstring to reflect accepted types and behavior. - Removed extraneous whitespace and added clarifying comments in related code paths. --- src/wf_psf/data/data_zernike_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 6ef2c93c..8dd74e4d 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -62,7 +62,6 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) ], axis=0, ) - if model_params.use_prior: if prior is not None: logger.warning( @@ -141,7 +140,7 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order)) - + # Pad each contribution to the max order and sum them for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded @@ -164,7 +163,8 @@ def assemble_zernike_contributions( model_params : RecursiveNamespace Parameters controlling which contributions to apply. zernike_prior : Optional[np.ndarray or tf.Tensor] - The precomputed Zernike prior (e.g., from PDC or another model). + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. centroid_dataset : Optional[object] Dataset used to compute centroid correction. Must have both training and test sets. positions : Optional[np.ndarray or tf.Tensor] @@ -184,8 +184,17 @@ def assemble_zernike_contributions( if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") if isinstance(zernike_prior, tf.Tensor): - zernike_prior = zernike_prior.numpy() + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + elif isinstance(zernike_prior, np.ndarray): zernike_contribution_list.append(zernike_prior) + else: + raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -220,7 +229,6 @@ def assemble_zernike_contributions( return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. 
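A minimal sketch of the prior type handling introduced in the patch above (standalone; it mirrors the guard rather than calling the wf_psf function):

    import numpy as np
    import tensorflow as tf

    def prior_to_numpy(zernike_prior):
        # tf.Tensor inputs can only be converted while eager execution is on.
        if isinstance(zernike_prior, tf.Tensor):
            if not tf.executing_eagerly():
                raise RuntimeError("Cannot call `.numpy()` outside of eager mode.")
            return zernike_prior.numpy()
        if isinstance(zernike_prior, np.ndarray):
            return zernike_prior
        raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.")

    assert isinstance(prior_to_numpy(tf.ones((4, 6))), np.ndarray)
    assert isinstance(prior_to_numpy(np.ones((4, 6))), np.ndarray)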
From e8f60758356b8c10df8abba19f008bfbbb1b6c37 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:25:35 +0200 Subject: [PATCH 090/129] Fix bug where Tensor zernike_prior was not appended after eager conversion --- src/wf_psf/data/data_zernike_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 8dd74e4d..1157be84 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -191,10 +191,10 @@ def assemble_zernike_contributions( "Zernike prior is a TensorFlow tensor but eager execution is disabled. " "Cannot call `.numpy()` outside of eager mode." ) - elif isinstance(zernike_prior, np.ndarray): - zernike_contribution_list.append(zernike_prior) - else: + + elif not isinstance(zernike_prior, np.ndarray): raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") From f8ba288f6d530eab2a59132583777aafa00d7ffd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:28:46 +0200 Subject: [PATCH 091/129] Update unit tests with latest changes to fixtures and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 90994dde..cc6e22de 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -12,10 +12,9 @@ assemble_zernike_contributions, compute_zernike_tip_tilt, ) -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace - @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -29,41 +28,45 @@ def mock_model_params(): def dummy_prior(): return np.ones((4, 6), dtype=np.float32) -@pytest.fixture -def dummy_positions(): - return np.random.rand(4, 2).astype(np.float32) @pytest.fixture def dummy_centroid_dataset(): return {"training": "dummy_train", "test": "dummy_test"} -def test_training_without_prior(mock_model_params): + +def test_training_without_prior(mock_model_params, mock_data): mock_model_params.use_prior = False - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((3, 2))} - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) - assert zinputs.centroid_dataset is data + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - np.testing.assert_array_equal( - zinputs.misalignment_positions, - np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) - ) -@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") -def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + expected_positions = np.concatenate([ + mock_data.training_data.dataset["positions"], + 
mock_data.test_data.dataset["positions"] + ]) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((2, 2))} - mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) - assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] - mock_get_prior.assert_called_once_with(data) def test_training_with_explicit_prior(mock_model_params, caplog): mock_model_params.use_prior = True @@ -224,17 +227,18 @@ def test_zero_order_contributions(): @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) result = assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=dummy_positions + positions = dummy_positions ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) @@ -275,7 +279,7 @@ def test_prior_as_tensor(mock_model_params): model_params=mock_model_params, zernike_prior=tensor_prior ) - + assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) @@ -294,7 +298,7 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" + """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -332,11 +336,11 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 + np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" + """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" # Mock the CentroidEstimator class mock_centroid_calc = 
mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -377,7 +381,6 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Process the displacements and expected values for each image in the batch expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters # Compare expected values with the actual arguments passed to the mock function From a97cc66a5f812ed8c2c7f7ec97edd6e114a7cebe Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:30:00 +0200 Subject: [PATCH 092/129] Set mock Zernike priors to None in test_data_utils.py helper module --- src/wf_psf/tests/test_data/test_data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index a5ead298..1ebc00cb 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -8,8 +8,8 @@ def __init__( self, training_positions, test_positions, - training_zernike_priors, - test_zernike_priors, + training_zernike_priors=None, + test_zernike_priors=None, noisy_stars=None, noisy_masks=None, stars=None, From 88e49bcc35fc7d4a38e5a59364d5b46df8bce520 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:31:12 +0200 Subject: [PATCH 093/129] Remove -1.0 multiplicative factor applied to Zernike tip and tilt values --- src/wf_psf/data/centroids.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 97cb069f..b01af2dd 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -32,12 +32,11 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_size : int, optional The batch size to use when processing the stars. Default is 16. - Returns ------- zernike_centroid_array : np.ndarray A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. 
""" star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") @@ -70,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + zk1_2_batch = compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) From a1d055b323df69b47d74d399c2625b09f1cdfbe5 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:42:28 +0200 Subject: [PATCH 094/129] Move TFPhysicalPolychromaticField.pad_zernikes to helper method pad_tf_zernikes in data_zernike_utils.py - Import pad_tf_zernikes into psf_model_physical_polychromatic.py - Replace calls to pad_zernikes with pad_tf_zernikes - Move padding zernike unit tests to data_zernike_utils_test.py - Update/Remove unit tests in psf_model_physical_polychromatic_test.py --- src/wf_psf/data/data_zernike_utils.py | 40 +++++ .../psf_model_physical_polychromatic.py | 14 +- .../test_data/data_zernike_utils_test.py | 75 ++++++++- .../psf_model_physical_polychromatic_test.py | 144 +++++------------- 4 files changed, 164 insertions(+), 109 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 1157be84..3a7fa90b 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -147,6 +147,46 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray return combined + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. + + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. 
+ """ + + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + def assemble_zernike_contributions( model_params, zernike_prior=None, diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 9979bcf1..918ffc4b 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -11,7 +11,11 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_np_obs_positions -from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes +) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, @@ -575,8 +579,8 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) @@ -613,8 +617,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index cc6e22de..afafc1db 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - ZernikeInputs, ZernikeInputsFactory, get_np_zernike_prior, pad_contribution_to_order, combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, + pad_tf_zernikes ) from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -297,6 +297,79 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm ) +def test_pad_zernikes_num_of_zernikes_equal(): + # Prepare your test tensors + zk_param = tf.constant([[[[1.0]]], [[[2.0]]]]) # Shape (2, 1, 1, 1) + zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]]) # Same shape + + # Reshape to (1, 2, 1, 1) + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) + + # Reset _n_zks_total 
to max number of zernikes (2 here) + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call pad_zernikes method + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert shapes are equal and correct + assert padded_zk_param.shape[1] == n_zks_total + assert padded_zk_prior.shape[1] == n_zks_total + + # If num zernikes already equal, output should be unchanged + np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) + np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + +def test_pad_zernikes_prior_greater_than_param(): + zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 5, 1, 1) + assert padded_zk_prior.shape == (1, 5, 1, 1) + + +def test_pad_zernikes_param_greater_than_prior(): + zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) + zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 4, 1, 1) + assert padded_zk_prior.shape == (1, 4, 1, 1) + + def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 13d78667..95ad6287 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,7 @@ import pytest import numpy as np import tensorflow as tf -from unittest.mock import PropertyMock +from unittest.mock import patch from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -29,15 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": 
np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return mock_instance @@ -49,122 +61,48 @@ def mock_model_params(mocker): return model_params_mock @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) 
- assert padded_zk_prior.shape == (1, 4, 1, 1) - +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField + instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + return instance def test_compute_zernikes(mocker, physical_layer_instance): - # Expected output of mock components + # Expected output of mock components padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) - expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - - # Patch tf_poly_Z_field property - mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], + dtype=tf.float32 +) + # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_poly_Z_field', - new_callable=PropertyMock, - return_value=mock_tf_poly_Z_field + "tf_poly_Z_field", + return_value=padded_zernike_param ) - # Patch tf_physical_layer property + # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_physical_layer', - new_callable=PropertyMock, - return_value=mock_tf_physical_layer + "tf_physical_layer", + mock_tf_physical_layer ) - # Patch pad_zernikes instance method directly (this one isn't a property) - mocker.patch.object( - physical_layer_instance, - 'pad_zernikes', + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior) ) - # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) # Assertions tf.debugging.assert_equal(zernike_coeffs, expected_values) - assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file + assert zernike_coeffs.shape == expected_values.shape From b721f2fffc698361eddeefd243210f85de7d1e7c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:59:07 +0200 Subject: [PATCH 095/129] Correct bug in test_load_inference_model - Add missing fixture arguments to mock_training_config and mock_inference_config - Add patch for DataHandler - Remove unused import --- .../test_inference/psf_inference_test.py | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 4add460b..91328300 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,11 +12,11 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, - PSFInferenceEngine + 
PSFInferenceEngine ) from wf_psf.utils.read_config import RecursiveNamespace @@ -29,7 +29,26 @@ def mock_training_config(): model_params=RecursiveNamespace( model_name="mock_model", output_Q=2, - output_dim=32 + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + + ) ) ) ) @@ -48,9 +67,10 @@ def mock_inference_config(): data_config_path=None ), model_params=RecursiveNamespace( + n_bins_lda=8, output_Q=1, output_dim=64 - ) + ), ) ) return inference_config @@ -179,9 +199,11 @@ def test_batch_size_positive(): assert inference.batch_size == 4 +@patch('wf_psf.inference.psf_inference.DataHandler') @patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config): - +def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) mock_config_handler.trained_model_path = "mock/path/to/model" mock_config_handler.training_config = mock_training_config @@ -202,8 +224,8 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_config_handler.training_config, - mock_config_handler.data_config, + mock_training_config, + mock_data_config, weights_path_pattern ) From 98d92a0372a11cce388291c8827fa7c0c05f7293 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 18 Aug 2025 15:33:26 +0200 Subject: [PATCH 096/129] Revert sign change applied to compute_centroid_correction --- src/wf_psf/data/centroids.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index b01af2dd..7052f833 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -69,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = compute_zernike_tip_tilt( + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) From 2e644d7918e63c7bbfd985e15322dc9c560d1489 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 19 Aug 2025 23:37:52 +0200 Subject: [PATCH 097/129] Refactor _prepare_positions_and_seds to enforce shape consistency and add validation - Ensure x_field and y_field are at least 1D and broadcast to positions of shape (n_samples, 2) - Validate that SEDs batch size matches the number of positions; raise ValueError if not - Validate that SEDs last dimension is 2 (flux, wavelength) - Process SEDs via data_handler as before - Removes need for broadcasting SEDs silently, avoiding hidden shape mismatches - Supports single-star, multi-star, and scalar inputs - Improves error messages for easier debugging in unit tests Also updates unit tests to cover: - Single-star input shapes - Mismatched x/y fields or SED batch size 
(ValueError cases) --- src/wf_psf/inference/psf_inference.py | 60 ++++++-- .../test_inference/psf_inference_test.py | 137 +++++++++++++++++- src/wf_psf/utils/utils.py | 41 ++++++ 3 files changed, 226 insertions(+), 12 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6bba2282..76619f5a 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -13,6 +13,7 @@ import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf +from wf_psf.utils.utils import ensure_batch from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf @@ -260,13 +261,40 @@ def output_dim(self): return self._output_dim def _prepare_positions_and_seds(self): - """Preprocess and return tensors for positions and SEDs.""" - positions = tf.convert_to_tensor( - np.array([self.x_field, self.y_field]).T, dtype=tf.float32 - ) - self.data_handler.process_sed_data(self.seds) - sed_data = self.data_handler.sed_data - return positions, sed_data + """ + Preprocess and return tensors for positions and SEDs with consistent shapes. + + Handles single-star, multi-star, and scalar inputs, ensuring: + - positions: shape (n_samples, 2) + - sed_data: shape (n_samples, n_bins, 2) + """ + # Ensure x_field and y_field are at least 1D + x_arr = np.atleast_1d(self.x_field) + y_arr = np.atleast_1d(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError(f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}") + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + + if sed_data.shape[2] != 2: + raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + def run_inference(self): """Run PSF inference and return the full PSF array.""" @@ -291,9 +319,23 @@ def get_psfs(self): self._ensure_psf_inference_completed() return self.engine.get_psfs() - def get_psf(self, index): + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + If only a single star was passed during instantiation, the given index is ignored and that star's PSF is returned. 
+ """ self._ensure_psf_inference_completed() - return self.engine.get_psf(index) + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + class PSFInferenceEngine: def __init__(self, trained_model, batch_size: int, output_dim: int): diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 91328300..ec9f0495 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,7 +12,7 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, @@ -21,6 +21,19 @@ from wf_psf.utils.read_config import RecursiveNamespace +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -83,14 +96,46 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins) + seds=np.random.rand(num_sources, num_bins, 2) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -308,3 +353,89 @@ def fake_compute_psfs(positions, seds): assert np.all(psfs_2 == expected_psfs) assert mock_compute_psfs.call_count == 1 + + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, 
mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (1, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (1, num_bins, 2)" + + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (2, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() \ No newline at end of file diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 1b1f2d6d..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -26,6 +26,47 @@ pass +def scale_to_range(input_array, old_range, new_range): + # Scale to [0,1] + input_array = (input_array - old_range[0]) / (old_range[1] - old_range[0]) + # Scale to new_range + input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] + return input_array + + +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. 
+ + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. + """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + + +def calc_wfe(zernike_basis, zks): + wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) + return wfe + + +def calc_wfe_rms(zernike_basis, zks, pupil_mask): + wfe = calc_wfe(zernike_basis, zks) + wfe_rms = np.sqrt(np.mean((wfe[pupil_mask] - np.mean(wfe[pupil_mask])) ** 2)) + return wfe_rms + + def generalised_sigmoid(x, max_val=1, power_k=1): """ Apply a generalized sigmoid function to the input. From dd99fe27b08a673a9a546cccfb0220f7e7b56544 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 21 Aug 2025 16:43:44 +0200 Subject: [PATCH 098/129] Fix tensor handling in ZernikeInputsFactory Explicitly convert positions to NumPy arrays with `.numpy()` and adjust inference path to read from `data.dataset` for positions and priors. --- src/wf_psf/data/data_zernike_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 3a7fa90b..4149d79e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_data.dataset["positions"], - data.test_data.dataset["positions"] + data.training_data.dataset["positions"].numpy(), + data.test_data.dataset["positions"].numpy() ], axis=0, ) @@ -72,11 +72,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) elif run_type == "inference": centroid_dataset = None - positions = data["positions"] + positions = data.dataset["positions"].numpy() if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") if prior is None: logger.warning( From 818dc73806dad8d350b8b48ddbdf68d1f8f6395c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 21 Aug 2025 11:50:45 +0200 Subject: [PATCH 099/129] Reformat and remove unused import --- src/wf_psf/data/data_handler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 120974eb..fe660940 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,7 +17,6 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf -from fractions import Fraction from typing import Optional, Union import logging @@ -184,12 +183,18 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["possitions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32) - + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if self.dataset_type == "train": - 
self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32) + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) elif self.dataset_type == "test": - self.dataset["stars"] = ensure_tensor(self.dataset["stars"], dtype=tf.float32) + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) def process_sed_data(self, sed_data): """ From f3d4f165b6b549c0eab21c348b45bddfe133dac7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:39:40 +0200 Subject: [PATCH 100/129] Correct zernike_prior extraction when dataset is a dict, reformat file --- src/wf_psf/data/data_zernike_utils.py | 76 ++++++++++++++++----------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 4149d79e..f7653540 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -26,13 +26,17 @@ @dataclass class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) - centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections class ZernikeInputsFactory: @staticmethod - def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: """Builds a ZernikeInputs dataclass instance based on run type and data. Parameters @@ -58,7 +62,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) positions = np.concatenate( [ data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy() + data.test_data.dataset["positions"].numpy(), ], axis=0, ) @@ -76,7 +80,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) if prior is None: logger.warning( @@ -89,7 +97,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions + misalignment_positions=positions, ) @@ -119,12 +127,14 @@ def get_np_zernike_prior(data): return zernike_prior + def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: """Pad a Zernike contribution array to the max Zernike order.""" current_order = contribution.shape[1] pad_width = ((0, 0), (0, max_order - current_order)) return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: """Combine multiple Zernike contributions, padding each to the max order before summing.""" if not contributions: @@ -228,12 +238,14 @@ def assemble_zernike_contributions( zernike_prior = zernike_prior.numpy() else: raise RuntimeError( - "Zernike prior is a TensorFlow tensor but eager execution is disabled. 
" - "Cannot call `.numpy()` outside of eager mode." + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." ) - + elif not isinstance(zernike_prior, np.ndarray): - raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -254,7 +266,9 @@ def assemble_zernike_contributions( ccd_misalignment = compute_ccd_misalignment(model_params, positions) zernike_contribution_list.append(ccd_misalignment) else: - logger.info("Skipping CCD misalignment correction (not enabled or no positions).") + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) # If no contributions, return zeros tensor to avoid crashes if not zernike_contribution_list: @@ -267,7 +281,7 @@ def assemble_zernike_contributions( combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. @@ -303,18 +317,19 @@ def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): * 3.0 ) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, pixel_sampling: float = 12e-6, - reference_shifts: list[float] = [-1/3, -1/3], + reference_shifts: list[float] = [-1 / 3, -1 / 3], sigma_init: float = 2.5, n_iter: int = 20, ) -> np.ndarray: """ Compute Zernike tip-tilt corrections for a batch of PSF images. - This function estimates the centroid shifts of multiple PSFs and computes + This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference. Parameters @@ -330,7 +345,7 @@ def compute_zernike_tip_tilt( pixel_sampling : float, optional The pixel size in meters. Defaults to `12e-6 m` (12 microns). reference_shifts : list[float], optional - The target centroid shifts in pixels, specified as `[dy, dx]`. + The target centroid shifts in pixels, specified as `[dy, dx]`. Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). sigma_init : float, optional Initial standard deviation for centroid estimation. Default is `2.5`. @@ -343,21 +358,18 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - + Notes ----- - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. 
""" from wf_psf.data.centroids import CentroidEstimator - + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( - im=star_images, - mask=star_masks, - sigma_init=sigma_init, - n_iter=n_iter - ) + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) shifts = centroid_estimator.get_intra_pixel_shifts() @@ -365,23 +377,25 @@ def compute_zernike_tip_tilt( reference_shifts = np.array(reference_shifts) # Reshape to ensure it's a column vector (1, 2) - reference_shifts = reference_shifts[None,:] - + reference_shifts = reference_shifts[None, :] + # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) - + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + # Compute displacements - displacements = (reference_shifts - shifts) # - + displacements = reference_shifts - shifts # + # Ensure the correct axis order for displacements (x-axis, then y-axis) - displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements - zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call - + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + # Reshape the result back to the original shape of displacements zk1_2_array = zk1_2_array.reshape(displacements.shape) - + return zk1_2_array From 829136a24b396c728bb3544f3b6f78077dd2b385 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:41:46 +0200 Subject: [PATCH 101/129] Replace np.array with Tensorflow tensors in unit test and fixtures, reformat files --- src/wf_psf/tests/test_data/conftest.py | 50 ++-- .../test_data/data_zernike_utils_test.py | 242 +++++++++++------- 2 files changed, 176 insertions(+), 116 deletions(-) diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 04a56893..47eed929 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -100,38 +100,35 @@ def mock_data(scope="module"): """Fixture to provide mock data for testing.""" # Mock positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) + training_positions = tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) # Define dummy 5x5 image patches for stars (mock star images) # Define varied values for 5x5 star images - noisy_stars = tf.constant([ - np.arange(25).reshape(5, 5), - np.arange(25, 50).reshape(5, 5) - ], dtype=tf.float32) - - noisy_masks = tf.constant([ - np.eye(5), - np.ones((5, 5)) - ], dtype=tf.float32) - - stars = tf.constant([ - np.full((5, 5), 100), - np.full((5, 5), 200) - ], dtype=tf.float32) - - masks = tf.constant([ - np.zeros((5, 5)), - np.tri(5) - ], dtype=tf.float32) + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), 
np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) return MockData( - training_positions, test_positions, training_zernike_priors, - test_zernike_priors, noisy_stars, noisy_masks, stars, masks + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, ) + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" @@ -140,11 +137,13 @@ def simple_image(scope="module"): image[:, 2, 2] = 1 # Place the star at the center for each image return image + @pytest.fixture def identity_mask(scope="module"): """Creates a mask where all pixels are fully considered.""" return np.ones((5, 5)) + @pytest.fixture def multiple_images(scope="module"): """Fixture for a batch of images with stars at different positions.""" @@ -154,6 +153,7 @@ def multiple_images(scope="module"): images[2, 3, 1] = 1 # Star at (3, 1) in image 2 return images + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index afafc1db..390d10f9 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,4 +1,3 @@ - import pytest import numpy as np from unittest.mock import MagicMock, patch @@ -10,11 +9,12 @@ combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace + @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -24,6 +24,7 @@ def mock_model_params(): param_hparams=RecursiveNamespace(n_zernikes=6), ) + @pytest.fixture def dummy_prior(): return np.ones((4, 6), dtype=np.float32) @@ -41,22 +42,28 @@ def test_training_without_prior(mock_model_params, mock_data): mock_data.training_data.dataset.pop("zernike_prior", None) mock_data.test_data.dataset.pop("zernike_prior", None) - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - expected_positions = np.concatenate([ - mock_data.training_data.dataset["positions"], - mock_data.test_data.dataset["positions"] - ]) + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) expected_priors = np.concatenate( ( @@ -77,7 +84,9 @@ def test_training_with_explicit_prior(mock_model_params, caplog): explicit_prior = np.array([9.0, 9.0, 9.0]) with caplog.at_level("WARNING"): - zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) assert 
"Zernike prior explicitly provided" in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -85,16 +94,24 @@ def test_training_with_explicit_prior(mock_model_params, caplog): def test_inference_with_dict_and_prior(mock_model_params): mock_model_params.use_prior = True - data = { - "positions": np.ones((5, 2)), - "zernike_prior": np.array([42.0, 0.0]) - } + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) assert zinputs.centroid_dataset is None - assert (zinputs.zernike_prior == data["zernike_prior"]).all() - np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) def test_invalid_run_type(mock_model_params): @@ -111,7 +128,7 @@ def test_get_np_zernike_prior(): # Construct fake DataConfigHandler structure using RecursiveNamespace data = RecursiveNamespace( training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), - test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), ) expected_prior = np.concatenate((training_prior, test_prior), axis=0) @@ -121,19 +138,24 @@ def test_get_np_zernike_prior(): # Assert shape and values match expected np.testing.assert_array_equal(result, expected_prior) + def test_pad_contribution_to_order(): # Input: batch of 2 samples, each with 3 Zernike coefficients - input_contribution = np.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]) - + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + max_order = 5 # Target size: pad to 5 coefficients - expected_output = np.array([ - [1.0, 2.0, 3.0, 0.0, 0.0], - [4.0, 5.0, 6.0, 0.0, 0.0], - ]) + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) padded = pad_contribution_to_order(input_contribution, max_order) @@ -149,6 +171,7 @@ def test_no_padding_needed(): assert output.shape == input_contribution.shape np.testing.assert_array_equal(output, input_contribution) + def test_padding_to_much_higher_order(): """Pad from order 2 to order 10.""" input_contribution = np.array([[1, 2], [3, 4]]) @@ -158,6 +181,7 @@ def test_padding_to_much_higher_order(): assert output.shape == (2, 10) np.testing.assert_array_equal(output, expected_output) + def test_empty_contribution(): """Test behavior with empty input array (0 features).""" input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients @@ -180,42 +204,44 @@ def test_zero_samples(): def test_combine_zernike_contributions_basic_case(): """Combine two contributions with matching sample count and varying order.""" - contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - contrib2 = np.array([[5], [6]]) # shape (2, 1) - expected = np.array([ - [1 + 5, 2 + 0], - [3 + 6, 4 + 0] - ]) # padded contrib2 to (2, 2) + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) result = combine_zernike_contributions([contrib1, contrib2]) np.testing.assert_array_equal(result, expected) + def test_combine_multiple_contributions(): """Combine three 
contributions.""" - c1 = np.array([[1, 2, 3]]) # shape (1, 3) - c2 = np.array([[4, 5]]) # shape (1, 2) - c3 = np.array([[6]]) # shape (1, 1) - expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) result = combine_zernike_contributions([c1, c2, c3]) np.testing.assert_array_equal(result, expected) + def test_empty_input_list(): """Raise ValueError when input list is empty.""" with pytest.raises(ValueError, match="No contributions provided."): combine_zernike_contributions([]) + def test_inconsistent_sample_count(): """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - c2 = np.array([[5, 6]]) # shape (1, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) with pytest.raises(ValueError): combine_zernike_contributions([c1, c2]) + def test_single_contribution(): """Combining a single contribution should return the same array (no-op).""" contrib = np.array([[7, 8, 9], [10, 11, 12]]) result = combine_zernike_contributions([contrib]) np.testing.assert_array_equal(result, contrib) + def test_zero_order_contributions(): """Contributions with 0 Zernike coefficients.""" contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients @@ -225,9 +251,12 @@ def test_zero_order_contributions(): assert result.shape == (2, 0) np.testing.assert_array_equal(result, expected) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) dummy_positions = np.full((4, 6), 1.0) @@ -236,12 +265,13 @@ def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_param model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions = dummy_positions + positions=dummy_positions, ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) + def test_prior_only(mock_model_params, dummy_prior): mock_model_params.correct_centroids = False mock_model_params.add_ccd_misalignments = False @@ -250,11 +280,12 @@ def test_prior_only(mock_model_params, dummy_prior): model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=None, - positions=None + positions=None, ) np.testing.assert_array_equal(result.numpy(), dummy_prior) + def test_no_contributions_returns_zeros(): model_params = RecursiveNamespace( use_prior=False, @@ -269,6 +300,7 @@ def test_no_contributions_returns_zeros(): assert result.shape == (1, 8) np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + def test_prior_as_tensor(mock_model_params): tensor_prior = tf.ones((4, 6), dtype=tf.float32) @@ -276,24 +308,28 @@ def test_prior_as_tensor(mock_model_params): mock_model_params.add_ccd_misalignments = False result = assemble_zernike_contributions( - model_params=mock_model_params, - zernike_prior=tensor_prior + model_params=mock_model_params, zernike_prior=tensor_prior ) assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert 
isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") -def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_inconsistent_shapes_raises_error( + mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_model_params.add_ccd_misalignments = False mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 - with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + with pytest.raises( + ValueError, match="All contributions must have the same number of samples" + ): assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=None + positions=None, ) @@ -307,14 +343,10 @@ def test_pad_zernikes_num_of_zernikes_equal(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reset _n_zks_total to max number of zernikes (2 here) - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call pad_zernikes method - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert shapes are equal and correct assert padded_zk_param.shape[1] == n_zks_total @@ -324,6 +356,7 @@ def test_pad_zernikes_num_of_zernikes_equal(): np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + def test_pad_zernikes_prior_greater_than_param(): zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) @@ -333,14 +366,10 @@ def test_pad_zernikes_prior_greater_than_param(): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 5, 1, 1) @@ -356,14 +385,10 @@ def test_pad_zernikes_param_greater_than_prior(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 4, 1, 1) @@ -374,16 +399,20 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class - mock_centroid_calc = 
mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02]] + ) # Shape (1, 2) # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) # Define test inputs (batch of 1 image) @@ -391,41 +420,58 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions # Run the function - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) - zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) + zernike_corrections = compute_zernike_tip_tilt( + simple_image, identity_mask, pixel_sampling, reference_shifts + ) # Expected shifts based on centroid calculation - expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters - expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters + expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters + expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters # Expected calls to the mocked function # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) # Check dy values similarly - np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 - np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 + np.testing.assert_allclose( + zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 + np.testing.assert_allclose( + zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 + def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" - + # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + mock_centroid_calc = mocker.patch( + "wf_psf.data.centroids.CentroidEstimator", autospec=True + ) # Create a mock instance and configure get_intra_pixel_shifts() mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, 
-0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] + ) # Shape (3, 2) # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) # Define test inputs (batch of 3 images) @@ -434,16 +480,18 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Run the function zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts - ) + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts, + ) # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + assert ( + len(mock_shift_fn.call_args_list) == 1 + ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + args, _ = mock_shift_fn.call_args_list[0] print("Shape of args[0]:", args[0].shape) print("Contents of args[0]:", args[0]) @@ -453,13 +501,25 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): args_array = np.array(args[0]).reshape(-1, 2) # Process the displacements and expected values for each image in the batch - expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] + ) # Expected y-axis shift in meters # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + np.testing.assert_allclose( + args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image - np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image \ No newline at end of file + np.testing.assert_allclose( + zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 for each image + np.testing.assert_allclose( + zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 for each image From 96460a5600e748c6c9809812f406d8f11a60bda8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:19:53 +0200 Subject: [PATCH 102/129] Eagerly initialise trainable layers in physical poly model constructor required for evaluation/inference --- .../models/psf_model_physical_polychromatic.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 
deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 918ffc4b..e2302da5 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -120,15 +120,15 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.run_type = self._get_run_type(data) self.obs_pos = self.get_obs_pos() - # Initialize the model parameters and layers + # Initialize the model parameters self.output_Q = model_params.output_Q self.l2_param = model_params.param_hparams.l2_param self.output_dim = model_params.output_dim - + # Initialise lazy loading of external Zernike prior self._external_prior = None - # Initialize the model parameters with non-default value + # Set Zernike Polynomial Coefficient Matrix if not None if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) @@ -158,9 +158,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): rotation_angle=self.model_params.obscuration_rotation_angle, ) - # Eagerly initialise tf_batch_poly_PSF + # Eagerly initialise model layers self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() - + _ = self.tf_poly_Z_field + _ = self.tf_np_poly_opd def _get_run_type(self, data): if hasattr(data, 'run_type'): From 06d715b6e91d2eb2e4b8db63bc87cb5c8cbe4812 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:21:59 +0200 Subject: [PATCH 103/129] fix: use expect_partial() when loading model weights for evaluation - Add status handling to model.load_weights() call - Use expect_partial() to suppress warnings about unused optimizer state - Allows successful weight loading for metrics evaluation when checkpoint contains training artifacts --- src/wf_psf/psf_models/psf_model_loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 797be8fc..c30d31ad 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -12,7 +12,7 @@ get_psf_model, get_psf_model_weights_filepath ) - +import tensorflow as tf logger = logging.getLogger(__name__) @@ -48,7 +48,9 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): try: logger.info(f"Loading PSF model weights from {weights_path}") - model.load_weights(weights_path) + status = model.load_weights(weights_path) + status.expect_partial() + except Exception as e: logger.exception("Failed to load model weights.") raise RuntimeError("Model weight loading failed.") from e From 5b3dee35adc843f352e44df86f898edad7215827 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:24:43 +0200 Subject: [PATCH 104/129] Add memory cleanup after training completion - Delete model reference and run garbage collection - Clear TensorFlow session to free GPU memory - Prevents OOM issues in subsequent operations or multiple training runs --- src/wf_psf/training/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index bb0e3df9..f2366ca7 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -7,6 +7,7 @@ """ +import gc import numpy as np import time import tensorflow as tf @@ -538,3 +539,8 @@ def train( final_time = time.time() logger.info("\nTotal elapsed time: %f" % (final_time - starting_time)) logger.info("\n Training 
complete..") + + # Clean up memory + del psf_model + gc.collect() + tf.keras.backend.clear_session() From 6a2d24dbb5b312744b9dc6eb4a85b51a9e40c26c Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 01:59:17 +0200 Subject: [PATCH 105/129] refactor: centralise PSF data extraction in data_handler - Introduce unified `get_data_array` for training/metrics/inference access - Add helpers `extract_star_data` and `_get_inference_data` - Remove redundant `get_np_obs_positions` - Move data handling logic out of `compute_centroid_corrections` - Standardise `centroid_dataset` as dict (stamps + optional masks) - Support optional keys (masks, priors) via `allow_missing` - Improve and unify docstrings for data extraction utilities - Add optional "sources" and "masks" attributes to `PSFInference - Add `correct_centroids` and `add_ccd_misalignments` as options to inference_config.yaml --- src/wf_psf/data/centroids.py | 35 ++- src/wf_psf/data/data_handler.py | 200 ++++++++++++++---- src/wf_psf/data/data_zernike_utils.py | 27 +-- src/wf_psf/inference/psf_inference.py | 6 +- .../psf_model_physical_polychromatic.py | 7 +- 5 files changed, 193 insertions(+), 82 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 7052f833..fb69c8dd 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -14,7 +14,7 @@ from typing import Optional -def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: +def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray: """Compute centroid corrections using Zernike polynomials. This function calculates the Zernike contributions required to match the centroid @@ -25,10 +25,13 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda model_params : RecursiveNamespace An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + batch_size : int, optional The batch size to use when processing the stars. Default is 16. @@ -39,24 +42,14 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. 
""" - star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index fe660940..bd802763 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -240,67 +240,43 @@ def process_sed_data(self, sed_data): self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. + """ + Extract and concatenate star-related data from training and test datasets. - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. Parameters ---------- data : DataConfigHandler Object containing training and test datasets. train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). Returns ------- np.ndarray - A NumPy array containing the concatenated data for the given keys. + Concatenated NumPy array containing the selected data from both + training and test sets. 
Raises ------ KeyError - If the specified keys do not exist in the training or test datasets. + If either the training or test dataset does not contain the + requested key. Notes ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ @@ -327,3 +303,145 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = False, +) -> np.ndarray | None: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. + run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default False + Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. + + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... 
print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_inference_data(data, effective_key, allow_missing) + except Exception as e: + if allow_missing: + return None + raise + + +def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: + """ + Extract inference data with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. + KeyError + If key is not found in dataset and allow_missing=False. + + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index f7653540..9bf9d1fd 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -16,6 +16,7 @@ import numpy as np import tensorflow as tf from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -54,29 +55,29 @@ def build( ------- ZernikeInputs """ - centroid_dataset = None - positions = None + centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets - positions = np.concatenate( - [ - data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy(), - ], - axis=0, - ) + stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + if model_params.use_prior: if prior is not None: logger.warning( - "Zernike prior explicitly provided; ignoring dataset-based prior 
despite use_prior=True."
+                    "Explicit prior provided; ignoring dataset-based prior."
                 )
             else:
                 prior = get_np_zernike_prior(data)

     elif run_type == "inference":
-        centroid_dataset = None
-        positions = data.dataset["positions"].numpy()
+        stamps = get_data_array(data=data, run_type=run_type, key="sources")
+        masks = get_data_array(data, run_type, key="masks", allow_missing=True)
+        centroid_dataset = {"stamps": stamps, "masks": masks}
+
+        positions = get_data_array(data=data, run_type=run_type, key="positions")

         if model_params.use_prior:
             # Try to extract prior from `data`, if present
diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 76619f5a..a37e6ba7 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -97,7 +97,7 @@ class PSFInference:
         Spectral energy distributions (SEDs).
     """

-    def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None):
+    def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None):

         self.inference_config_path = inference_config_path

@@ -105,6 +105,8 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None):
         self.x_field = x_field
         self.y_field = y_field
         self.seds = seds
+        self.sources = sources
+        self.masks = masks

         # Internal caches for lazy-loading
         self._config_handler = None
@@ -157,7 +159,7 @@ def _prepare_dataset_for_inference(self):
         positions = self.get_positions()
         if positions is None:
             return None
-        return {"positions": positions}
+        return {"positions": positions, "sources": self.sources, "masks": self.masks}

     @property
     def data_handler(self):
diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
index e2302da5..c8086cb7 100644
--- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
+++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py
@@ -10,7 +10,7 @@
 from typing import Optional
 import tensorflow as tf
 from tensorflow.python.keras.engine import data_adapter
-from wf_psf.data.data_handler import get_np_obs_positions
+from wf_psf.data.data_handler import get_data_array
 from wf_psf.data.data_zernike_utils import (
     ZernikeInputsFactory,
     assemble_zernike_contributions,
@@ -203,10 +203,7 @@ def save_nonparam_history(self) -> bool:

     def get_obs_pos(self):
         assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}"

-        if self.run_type in {"training", "simulation", "metrics"}:
-            raw_pos = get_np_obs_positions(self.data)
-        else:
-            raw_pos = self.data.dataset["positions"]
+        raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions")

         obs_pos = ensure_tensor(raw_pos, dtype=tf.float32)

From 8376fe1de78bdf31e23ed796ad6c4564d7dd17b2 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Thu, 4 Sep 2025 02:41:53 +0200
Subject: [PATCH 106/129] Add correct_centroids and add_ccd_misalignments
 options to inference_config.yaml (forgot to stage with previous commit)

---
 config/inference_conf.yaml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml
index c9d29cb8..927723c7 100644
--- a/config/inference_conf.yaml
+++ b/config/inference_conf.yaml
@@ -30,3 +30,8 @@ inference:

     # Dimension of the pixel PSF postage stamp
     output_dim: 64
+    # Flag to perform centroid error correction
+    correct_centroids: False
+
+    # Flag to perform CCD misalignment error correction
+    add_ccd_misalignments: True

From 
77434be4fa9eb4d43f8ccafdcb4424e7cd72292b Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 04:35:23 +0200 Subject: [PATCH 107/129] Update PSFInference doc string with new optional attributes --- src/wf_psf/inference/psf_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index a37e6ba7..5c27fdbf 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -95,6 +95,10 @@ class PSFInference: y coordinates in SHE convention. seds : array-like, optional Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. """ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): From 76a449c009cfedce003afd755881467a230033f6 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 05:44:34 +0200 Subject: [PATCH 108/129] Rename _get_inference_data to _get_direct_data --- src/wf_psf/data/data_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bd802763..c0ddf7b0 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -394,16 +394,16 @@ def get_data_array( if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference - return _get_inference_data(data, effective_key, allow_missing) + return _get_direct_data(data, effective_key, allow_missing) except Exception as e: if allow_missing: return None raise -def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: """ - Extract inference data with proper error handling and type conversion. + Extract data directly with proper error handling and type conversion. 
Parameters
     ----------

From b0874e76622948cd5dc3ff8499caeb445dbc11af Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 5 Sep 2025 10:37:27 -0400
Subject: [PATCH 109/129] Reformat with black

---
 src/wf_psf/__init__.py                        |   6 +-
 src/wf_psf/data/centroids.py                  |  29 ++-
 src/wf_psf/data/data_zernike_utils.py         |   6 +-
 src/wf_psf/inference/psf_inference.py         |  84 ++++----
 .../psf_model_physical_polychromatic.py       | 125 ++++++------
 src/wf_psf/psf_models/psf_model_loader.py     |  16 +-
 src/wf_psf/psf_models/tf_modules/tf_layers.py |   2 +-
 src/wf_psf/psf_models/tf_modules/tf_utils.py  |  42 ++--
 .../masked_loss/results/plot_results.ipynb    | 164 ++++++++++++----
 src/wf_psf/tests/test_data/test_data_utils.py |  22 ++-
 .../test_inference/psf_inference_test.py      | 183 +++++++++++-------
 .../psf_model_physical_polychromatic_test.py  |  39 ++--
 .../tests/test_psf_models/psf_models_test.py  |   2 +-
 src/wf_psf/tests/test_utils/utils_test.py     |  42 ++--
 src/wf_psf/training/train.py                  |  17 +-
 15 files changed, 480 insertions(+), 299 deletions(-)

diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py
index d4394f09..988b02fe 100644
--- a/src/wf_psf/__init__.py
+++ b/src/wf_psf/__init__.py
@@ -2,6 +2,6 @@
 # Dynamically import modules to trigger side effects when wf_psf is imported
 importlib.import_module("wf_psf.psf_models.psf_models")
-importlib.import_module("wf_psf.psf_models.psf_model_semiparametric")
-importlib.import_module("wf_psf.psf_models.psf_model_physical_polychromatic")
-importlib.import_module("wf_psf.psf_models.tf_psf_field")
+importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric")
+importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic")
+importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field")
diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py
index fb69c8dd..20391793 100644
--- a/src/wf_psf/data/centroids.py
+++ b/src/wf_psf/data/centroids.py
@@ -14,7 +14,9 @@ from typing import Optional


-def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray:
+def compute_centroid_correction(
+    model_params, centroid_dataset, batch_size: int = 1
+) -> np.ndarray:
     """Compute centroid corrections using Zernike polynomials.

     This function calculates the Zernike contributions required to match the centroid of each observed star.
@@ -31,15 +33,15 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray:
         Array of star postage stamps (required).
         - "masks" : Optional[np.ndarray]
             Array of star masks (optional, can be None).
-    
+
     batch_size : int, optional
         The batch size to use when processing the stars. Default is 1.

     Returns
     -------
     zernike_centroid_array : np.ndarray
-        A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of 
-        observed stars. The array contains the computed Zernike (Z1, Z2) contributions, 
+        A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of
+        observed stars. The array contains the computed Zernike (Z1, Z2) contributions,
         with zero padding applied to the first column to ensure a consistent shape.
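 
     Examples
     --------
     A minimal sketch with illustrative shapes; ``model_params`` is assumed to
     provide the ``pix_sampling`` and ``reference_shifts`` attributes described
     above, and ``np`` is the module-level NumPy import:
 
     >>> rng = np.random.default_rng(0)
     >>> centroid_dataset = {"stamps": rng.random((4, 32, 32)), "masks": None}
     >>> zk = compute_centroid_correction(model_params, centroid_dataset, batch_size=2)
     >>> zk.shape
     (4, 3)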
""" # Retrieve stamps and masks from centroid_dataset @@ -51,15 +53,17 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] n_stars = len(star_postage_stamps) zernike_centroid_array = [] # Batch process the stars for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i:i + batch_size] - batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch zk1_2_batch = -1.0 * compute_zernike_tip_tilt( @@ -67,11 +71,19 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= ) # Zero pad array for each batch and append - zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) # Combine all batches into a single array return np.concatenate(zernike_centroid_array, axis=0) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, @@ -119,6 +131,7 @@ def compute_zernike_tip_tilt( - The Zernike coefficients are computed in the WaveDiff convention. """ from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 9bf9d1fd..399ee9ef 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -58,12 +58,14 @@ def build( centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) masks = get_data_array(data, run_type, key="masks", allow_missing=True) centroid_dataset = {"stamps": stamps, "masks": masks} positions = get_data_array(data=data, run_type=run_type, key="positions") - + if model_params.use_prior: if prior is not None: logger.warning( diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5c27fdbf..c7d73249 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -28,7 +28,6 @@ def __init__(self, inference_config_path: str): self.training_config = None self.data_config = None - def load_configs(self): """Load configuration files based on the inference config.""" self.inference_config = read_conf(self.inference_config_path) @@ -39,7 +38,6 @@ def load_configs(self): # Load the data configuration self.data_config = read_conf(self.data_config_path) - def set_config_paths(self): """Extract and set the configuration paths.""" # Set config paths @@ -47,10 +45,11 @@ def set_config_paths(self): self.trained_model_path = Path(config_paths.trained_model_path) self.model_subdir = config_paths.model_subdir - self.trained_model_config_path = 
self.trained_model_path / config_paths.trained_model_config_path + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) self.data_config_path = config_paths.data_config_path - @staticmethod def overwrite_model_params(training_config=None, inference_config=None): """ @@ -75,7 +74,6 @@ def overwrite_model_params(training_config=None, inference_config=None): if hasattr(model_params, key): setattr(model_params, key, value) - class PSFInference: """ @@ -101,7 +99,15 @@ class PSFInference: Corresponding masks for the sources (same shape as sources). Defaults to None. """ - def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): self.inference_config_path = inference_config_path @@ -111,7 +117,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self.seds = seds self.sources = sources self.masks = masks - + # Internal caches for lazy-loading self._config_handler = None self._simPSF = None @@ -123,7 +129,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self._output_dim = None # Initialise PSF Inference engine - self.engine = None + self.engine = None @property def config_handler(self): @@ -157,7 +163,6 @@ def simPSF(self): self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF - def _prepare_dataset_for_inference(self): """Prepare dataset dictionary for inference, returning None if positions are invalid.""" positions = self.get_positions() @@ -167,7 +172,7 @@ def _prepare_dataset_for_inference(self): @property def data_handler(self): - if self._data_handler is None: + if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( dataset_type="inference", @@ -176,7 +181,7 @@ def data_handler(self): n_bins_lambda=self.n_bins_lambda, load_data=False, dataset=self._prepare_dataset_for_inference(), - sed_data = self.seds, + sed_data=self.seds, ) self._data_handler.run_type = "inference" return self._data_handler @@ -190,13 +195,13 @@ def trained_psf_model(self): def get_positions(self): """ Combine x_field and y_field into position pairs. - + Returns ------- numpy.ndarray Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None. - + Raises ------ ValueError @@ -204,25 +209,27 @@ def get_positions(self): """ if self.x_field is None or self.y_field is None: return None - + x_arr = np.asarray(self.x_field) y_arr = np.asarray(self.y_field) - + if x_arr.size == 0 or y_arr.size == 0: return None if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") - + raise ValueError( + f"x_field and y_field must have the same length. 
" + f"Got {x_arr.size} and {y_arr.size}" + ) + # Flatten arrays to handle any input shape, then stack x_flat = x_arr.flatten() y_flat = y_arr.flatten() - + return np.column_stack((x_flat, y_flat)) def load_inference_model(self): - """Load the trained PSF model based on the inference configuration.""" + """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -231,9 +238,9 @@ def load_inference_model(self): weights_path_pattern = os.path.join( model_path, model_dir, - f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", ) - + # Load the trained PSF model return load_trained_psf_model( self.training_config, @@ -244,7 +251,9 @@ def load_inference_model(self): @property def n_bins_lambda(self): if self._n_bins_lambda is None: - self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) return self._n_bins_lambda @property @@ -279,8 +288,10 @@ def _prepare_positions_and_seds(self): y_arr = np.atleast_1d(self.y_field) if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) # Combine into positions array (n_samples, 2) positions = np.column_stack((x_arr, y_arr)) @@ -288,12 +299,16 @@ def _prepare_positions_and_seds(self): # Ensure SEDs have shape (n_samples, n_bins, 2) sed_data = ensure_batch(self.seds) - + if sed_data.shape[0] != positions.shape[0]: - raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) if sed_data.shape[2] != 2: - raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}" + ) # Process SEDs through the data handler self.data_handler.process_sed_data(sed_data) @@ -301,7 +316,6 @@ def _prepare_positions_and_seds(self): return positions, sed_data_tensor - def run_inference(self): """Run PSF inference and return the full PSF array.""" # Prepare the configuration for inference @@ -332,7 +346,7 @@ def get_psf(self, index: int = 0) -> np.ndarray: If only a single star was passed during instantiation, the index defaults to 0. 
""" self._ensure_psf_inference_completed() - + inferred_psfs = self.engine.get_psfs() # If a single-star batch, ignore index bounds @@ -358,7 +372,9 @@ def inferred_psfs(self) -> np.ndarray: def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: """Compute and cache PSFs for the input source parameters.""" n_samples = positions.shape[0] - self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32) + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) # Initialize counter counter = 0 @@ -370,14 +386,14 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: batch_pos = positions[counter:end_sample, :] batch_seds = sed_data[counter:end_sample, :, :] batch_inputs = [batch_pos, batch_seds] - + # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter counter = end_sample - + return self._inferred_psfs def get_psfs(self) -> np.ndarray: @@ -391,5 +407,3 @@ def get_psf(self, index: int) -> np.ndarray: if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs[index] - - diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index c8086cb7..e2c84868 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -12,9 +12,9 @@ from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_data_array from wf_psf.data.data_zernike_utils import ( - ZernikeInputsFactory, + ZernikeInputsFactory, assemble_zernike_contributions, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -124,7 +124,7 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.output_Q = model_params.output_Q self.l2_param = model_params.param_hparams.l2_param self.output_dim = model_params.output_dim - + # Initialise lazy loading of external Zernike prior self._external_prior = None @@ -134,25 +134,26 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): # Compute contributions once eagerly (outside graph) zks_total_contribution_np = self._assemble_zernike_contributions().numpy() - self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) - + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 + ) + # Compute n_zks_total as int self._n_zks_total = max( self.model_params.param_hparams.n_zernikes, - zks_total_contribution_np.shape[1] + zks_total_contribution_np.shape[1], ) - - # Precompute zernike maps as tf.float32 + + # Precompute zernike maps as tf.float32 self._zernike_maps = psfm.generate_zernike_maps_3d( - n_zernikes=self._n_zks_total, - pupil_diam=self.model_params.pupil_diameter - ) + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter + ) - # Precompute OPD dimension + # Precompute OPD dimension self._opd_dim = self._zernike_maps.shape[1] # Precompute obscurations as tf.complex64 - self._obscurations = psfm.tf_obscurations( + self._obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, 
N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, @@ -164,10 +165,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): _ = self.tf_np_poly_opd def _get_run_type(self, data): - if hasattr(data, 'run_type'): + if hasattr(data, "run_type"): run_type = data.run_type - elif isinstance(data, dict) and 'run_type' in data: - run_type = data['run_type'] + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] else: raise ValueError("data must have a 'run_type' attribute or key") @@ -193,17 +194,28 @@ def _assemble_zernike_contributions(self): @property def save_param_history(self) -> bool: """Check if the model should save the optimization history for parametric features.""" - return getattr(self.model_params.param_hparams, "save_optim_history_param", False) - + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) + @property def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" - return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) def get_obs_pos(self): - assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - - raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions") + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" + + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" + ) obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) @@ -213,7 +225,7 @@ def get_obs_pos(self): @property def zks_total_contribution(self): return self._zks_total_contribution - + @property def n_zks_total(self): """Get the total number of Zernike coefficients.""" @@ -254,40 +266,40 @@ def tf_physical_layer(self): """Lazy loading of the physical layer of the PSF model.""" if not hasattr(self, "_tf_physical_layer"): self._tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_total_contribution, - interpolation_type=self.model_params.interpolation_type, - interpolation_args=self.model_params.interpolation_args, - ) + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) return self._tf_physical_layer - + @property def tf_zernike_OPD(self): """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" if not hasattr(self, "_tf_zernike_OPD"): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - + def _build_tf_batch_poly_PSF(self): """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" return TFBatchPolychromaticPSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) + obscurations=self.obscurations, + output_Q=self.output_Q, + output_dim=self.output_dim, + ) @property def tf_np_poly_opd(self): """Lazy loading of the non-parametric polynomial variations OPD layer.""" if not hasattr(self, "_tf_np_poly_opd"): self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=self.model_params.x_lims, - y_lims=self.model_params.y_lims, - random_seed=self.model_params.param_hparams.random_seed, - 
d_max=self.model_params.nonparam_hparams.d_max_nonparam, - opd_dim=self.opd_dim, - ) + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) return self._tf_np_poly_opd def get_coeff_matrix(self): @@ -313,7 +325,6 @@ def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None: """ self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None: """Set the output sampling rate (output_Q) for PSF generation. @@ -453,16 +464,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -505,13 +516,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -536,13 +547,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -677,27 +688,25 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) # Add L2 regularization loss on parametric OPD maps - self.add_loss( - self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) - ) + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) # Combine both contributions opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index c30d31ad..e445f7af 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,15 +7,14 @@ Author: Jennifer Pollack """ + import logging -from wf_psf.psf_models.psf_models import ( - get_psf_model, - get_psf_model_weights_filepath -) +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath import tensorflow as tf logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, 
weights_path_pattern): """ Loads a trained PSF model and applies saved weights. @@ -40,9 +39,11 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): RuntimeError If loading the model weights fails for any reason. """ - model = get_psf_model(training_conf.training.model_params, - training_conf.training.training_hparams, - data_conf) + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) weights_path = get_psf_model_weights_filepath(weights_path_pattern) @@ -55,4 +56,3 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): logger.exception("Failed to load model weights.") raise RuntimeError("Model weight loading failed.") from e return model - diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index 5bb5e2df..7b22b5e8 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) + class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. @@ -964,7 +965,6 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] - def call(self, positions): """Calculate the prior Zernike coefficients for a batch of positions. diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index d0a2002c..09540e60 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -1,15 +1,15 @@ """TensorFlow Utilities Module. -Provides lightweight utility functions for safely converting and managing data types +Provides lightweight utility functions for safely converting and managing data types within TensorFlow-based workflows. Includes: - `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype -These tools are designed to support PSF model components, including lazy property evaluation, +These tools are designed to support PSF model components, including lazy property evaluation, data input validation, and type normalization. -This module is intended for internal use in model layers and inference components to enforce +This module is intended for internal use in model layers and inference components to enforce TensorFlow-compatible inputs. Authors: Jennifer Pollack @@ -22,68 +22,69 @@ @tf.function def find_position_indices(obs_pos, batch_positions): """Find indices of batch positions within observed positions using vectorized operations. - - This function locates the indices of multiple query positions within a + + This function locates the indices of multiple query positions within a reference set of observed positions using broadcasting and vectorized operations. Each position in the batch must have an exact match in the observed positions. Parameters ---------- obs_pos : tf.Tensor - Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of observed positions. Each row contains [x, y] coordinates. batch_positions : tf.Tensor - Query positions tensor of shape (batch_size, 2), where batch_size is the number + Query positions tensor of shape (batch_size, 2), where batch_size is the number of positions to look up. Each row contains [x, y] coordinates. 
- + Returns ------- indices : tf.Tensor - Tensor of shape (batch_size,) containing the indices of each batch position + Tensor of shape (batch_size,) containing the indices of each batch position within obs_pos. The dtype is tf.int64. - + Raises ------ tf.errors.InvalidArgumentError If any position in batch_positions is not found in obs_pos. - + Notes ----- - Uses exact equality matching - positions must match exactly. More efficient than + Uses exact equality matching - positions must match exactly. More efficient than iterative lookups for multiple positions due to vectorized operations. """ # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) obs_expanded = tf.expand_dims(obs_pos, 0) pos_expanded = tf.expand_dims(batch_positions, 1) - + # Compare all positions at once: (batch_size, n_obs) matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) - + # Find the index of the matching position for each batch item - # argmax returns the first True value's index along axis=1 + # argmax returns the first True value's index along axis=1 indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) - + # Verify all positions were found tf.debugging.assert_equal( tf.reduce_all(tf.reduce_any(matches, axis=1)), True, - message="Some positions not found in obs_pos" + message="Some positions not found in obs_pos", ) - + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. - + Parameters ---------- input_array : array-like, tf.Tensor, or np.ndarray The input to convert. dtype : tf.DType, optional The desired TensorFlow dtype (default: tf.float32). - + Returns ------- tf.Tensor @@ -97,4 +98,3 @@ def ensure_tensor(input_array, dtype=tf.float32): else: # Convert numpy arrays or other types to tensor return tf.convert_to_tensor(input_array, dtype=dtype) - diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", 
"\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - "print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - 
"ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - 
"plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index 1ebc00cb..de111427 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -1,8 +1,13 @@ - class MockDataset: def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} - + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + class MockData: def __init__( self, @@ -16,15 +21,16 @@ def __init__( masks=None, ): self.training_data = MockDataset( - positions=training_positions, + positions=training_positions, zernike_priors=training_zernike_priors, star_type="noisy_stars", stars=noisy_stars, - masks=noisy_masks) + masks=noisy_masks, + ) self.test_data = MockDataset( - positions=test_positions, + positions=test_positions, zernike_priors=test_zernike_priors, star_type="stars", stars=stars, - masks=masks) - + masks=masks, + ) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index ec9f0495..05728de7 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -14,13 +14,14 @@ from types 
import SimpleNamespace from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( - InferenceConfigHandler, + InferenceConfigHandler, PSFInference, - PSFInferenceEngine + PSFInferenceEngine, ) from wf_psf.utils.read_config import RecursiveNamespace + def _patch_data_handler(): """Helper for patching data_handler to avoid full PSF logic.""" patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) @@ -34,6 +35,7 @@ def fake_process(x): mock_instance.process_sed_data.side_effect = fake_process return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -60,13 +62,13 @@ def mock_training_config(): LP_filter_length=3, param_hparams=RecursiveNamespace( n_zernikes=10, - - ) - ) - ) + ), + ), + ) ) return training_config + @pytest.fixture def mock_inference_config(): inference_config = RecursiveNamespace( @@ -74,16 +76,12 @@ def mock_inference_config(): batch_size=16, cycle=2, configs=RecursiveNamespace( - trained_model_path='/path/to/trained/model', - model_subdir='psf_model', - trained_model_config_path='config/training_config.yaml', - data_config_path=None - ), - model_params=RecursiveNamespace( - n_bins_lda=8, - output_Q=1, - output_dim=64 + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), ) ) return inference_config @@ -96,14 +94,18 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins, 2) + seds=np.random.rand(num_sources, num_bins, 2), ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -116,9 +118,10 @@ def psf_test_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } + @pytest.fixture def psf_single_star_setup(mock_inference_config): num_sources = 1 @@ -128,14 +131,18 @@ def psf_single_star_setup(mock_inference_config): # Single position mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) # Shape (1, 2, num_bins) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", - x_field=0.1, # scalar for single star + x_field=0.1, # scalar for single star y_field=0.1, - seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() 
inference._config_handler.inference_config = mock_inference_config @@ -148,7 +155,7 @@ def psf_single_star_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } @@ -165,7 +172,9 @@ def test_set_config_paths(mock_inference_config): # Assertions assert config_handler.trained_model_path == Path("/path/to/trained/model") assert config_handler.model_subdir == "psf_model" - assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) assert config_handler.data_config_path == None @@ -175,15 +184,19 @@ def test_overwrite_model_params(mock_training_config, mock_inference_config): training_config = mock_training_config inference_config = mock_inference_config - InferenceConfigHandler.overwrite_model_params( - training_config, inference_config - ) + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) # Assert that the model_params were overwritten correctly - assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" - assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" - - assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" def test_prepare_configs(mock_training_config, mock_inference_config): @@ -196,7 +209,7 @@ def test_prepare_configs(mock_training_config, mock_inference_config): original_model_params = mock_training_config.training.model_params # Instantiate PSFInference - psf_inf = PSFInference('/dummy/path.yaml') + psf_inf = PSFInference("/dummy/path.yaml") # Mock the config handler attribute with a mock InferenceConfigHandler mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -204,7 +217,9 @@ def test_prepare_configs(mock_training_config, mock_inference_config): mock_config_handler.inference_config = inference_config # Patch the overwrite_model_params to use the real static method - mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) psf_inf._config_handler = mock_config_handler @@ -223,30 +238,43 @@ def test_config_handler_lazy_load(monkeypatch): class DummyHandler: def load_configs(self): - called['load'] = True + called["load"] = True self.inference_config = {} self.training_config = {} self.data_config = {} - def overwrite_model_params(self, *args): pass - monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) inference.prepare_configs() - assert 'load' in called # Confirm lazy load happened + assert "load" in called # Confirm lazy load happened + def test_batch_size_positive(): inference = PSFInference("dummy_path.yaml") inference._config_handler = MagicMock() 
inference._config_handler.inference_config = SimpleNamespace( - inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) ) assert inference.batch_size == 4 -@patch('wf_psf.inference.psf_inference.DataHandler') -@patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): mock_data_config = MagicMock() mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -255,29 +283,33 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mo mock_config_handler.inference_config = mock_inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + psf_inf = PSFInference("dummy_path.yaml") psf_inf._config_handler = mock_config_handler psf_inf.load_inference_model() weights_path_pattern = os.path.join( - mock_config_handler.trained_model_path, - mock_config_handler.model_subdir, - f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" - ) + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_training_config, - mock_data_config, - weights_path_pattern + mock_training_config, mock_data_config, weights_path_pattern ) -@patch.object(PSFInference, 'prepare_configs') -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -294,8 +326,11 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_ mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) mock_prepare_configs.assert_called_once() + @patch("wf_psf.inference.psf_inference.psf_models.simPSF") -def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): """Test that simPSF uses the updated model parameters.""" training_config = mock_training_config inference_config = mock_inference_config @@ -315,7 +350,7 @@ def 
test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc mock_config_handler.inference_config = inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + modeller = PSFInference("dummy_path.yaml") modeller._config_handler = mock_config_handler @@ -330,9 +365,11 @@ def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc assert result.output_Q == expected_output_Q -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_get_psfs_runs_inference( + mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -355,7 +392,6 @@ def fake_compute_psfs(positions, seds): assert mock_compute_psfs.call_count == 1 - def test_single_star_inference_shape(psf_single_star_setup): setup = psf_single_star_setup @@ -374,9 +410,11 @@ def test_single_star_inference_shape(psf_single_star_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (1, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (1, num_bins, 2)" - + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" def test_multiple_star_inference_shape(psf_test_setup): @@ -398,9 +436,12 @@ def test_multiple_star_inference_shape(psf_test_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (2, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (2, num_bins, 2)" - + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + def test_valueerror_on_mismatched_batches(psf_single_star_setup): """Raise if sed_data batch size != positions batch size and sed_data != 1.""" @@ -416,7 +457,9 @@ def test_valueerror_on_mismatched_batches(psf_single_star_setup): inference.seds = bad_sed inference.positions = np.ones((1, 2), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): inference._prepare_positions_and_seds() finally: patcher.stop() @@ -435,7 +478,9 @@ def test_valueerror_on_mismatched_positions(psf_single_star_setup): inference.x_field = np.ones((3, 1), dtype=np.float32) inference.y_field = np.ones((3, 1), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): inference._prepare_positions_and_seds() finally: - patcher.stop() \ No newline at end of file + patcher.stop() diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 95ad6287..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ 
b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -32,16 +32,16 @@ def zks_prior(): def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) mock_instance.run_type = "training" - + training_dataset = { "positions": np.array([[1, 2], [3, 4]]), "zernike_prior": zks_prior, - "noisy_stars": np.zeros((2, 1, 1, 1)), + "noisy_stars": np.zeros((2, 1, 1, 1)), } test_dataset = { "positions": np.array([[5, 6], [7, 8]]), "zernike_prior": zks_prior, - "stars": np.zeros((2, 1, 1, 1)), + "stars": np.zeros((2, 1, 1, 1)), } mock_instance.training_data = mocker.Mock() @@ -60,44 +60,53 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock + @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data): # Patch expensive methods during construction to avoid errors - with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): - from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField - instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) return instance + def test_compute_zernikes(mocker, physical_layer_instance): # Expected output of mock components - padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 + ) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) n_zks_total = physical_layer_instance.n_zks_total expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) expected_values = tf.constant( - [[[[v]] for v in expected_values_list]], - dtype=tf.float32 -) + [[[[v]] for v in expected_values_list]], dtype=tf.float32 + ) # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, "tf_poly_Z_field", - return_value=padded_zernike_param + return_value=padded_zernike_param, ) # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - TFPhysicalPolychromaticField, - "tf_physical_layer", - mock_tf_physical_layer + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) # Patch pad_tf_zernikes function mocker.patch( "wf_psf.data.data_zernike_utils.pad_tf_zernikes", - return_value=(padded_zernike_param, padded_zernike_prior) + return_value=(padded_zernike_param, padded_zernike_prior), ) # Run the test diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index b7c906f6..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -10,7 +10,7 @@ from wf_psf.psf_models import psf_models from wf_psf.psf_models.models import ( psf_model_semiparametric, - psf_model_physical_polychromatic + psf_model_physical_polychromatic, ) import tensorflow as tf import numpy as np 
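The physical_layer_instance fixture above patches an expensive method for the duration of the constructor only, so the object can be built cheaply in tests. A minimal, self-contained sketch of that pattern follows; the class and method names here are illustrative, not part of wf_psf:

    from unittest.mock import patch

    class Model:
        def __init__(self):
            # Normally triggers expensive work at construction time.
            self.table = self.build_table()

        def build_table(self):
            raise RuntimeError("too expensive for a unit test")

    def make_cheap_model():
        # Patch only while __init__ runs; the instance keeps the canned value
        # even after the patch is reverted on exit from the context manager.
        with patch.object(Model, "build_table", return_value=[1.0, 2.0]):
            return Model()

    instance = make_cheap_model()
    assert instance.table == [1.0, 2.0]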
diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index dacf0bdc..cc7f2a2b 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -20,11 +20,6 @@ from unittest import mock - -def test_sanity(): - assert 1 + 1 == 2 - - def test_downsample_basic(): """Test apply_mask when a zeroed mask is provided.""" img_dim = (10, 10) @@ -37,9 +32,9 @@ def test_downsample_basic(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_initialization(): @@ -121,9 +116,9 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), ( - "apply_mask should return the window when mask is None." - ) + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." def test_apply_mask_with_valid_mask(): @@ -139,9 +134,9 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), ( - "apply_mask did not apply the mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." def test_apply_mask_with_zeroed_mask(): @@ -156,9 +151,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." 
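# A note on the assert reformatting above: parentheses must wrap only the
# condition, never the whole statement. For example,
#     assert (np.array_equal(result, expected), "message")
# builds a two-element tuple, which is always truthy, so the test could never
# fail; the multi-line form keeps the message outside the parentheses:
#     assert np.array_equal(result, expected), "message"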
def test_unobscured_zernike_projection(): @@ -252,6 +247,7 @@ def test_tf_decompose_obscured_opd_basis(): assert rmse_error < tol + def test_downsample_basic(): """Downsample a small array to a smaller square size.""" arr = np.arange(16).reshape(4, 4).astype(np.float32) @@ -262,9 +258,10 @@ def test_downsample_basic(): assert result.shape == (output_dim, output_dim), "Output shape mismatch" # Values should be averaged/downsampled; simple check - assert np.all(result >= arr.min()) and np.all(result <= arr.max()), \ - "Values outside input range" - + assert np.all(result >= arr.min()) and np.all( + result <= arr.max() + ), "Values outside input range" + def test_downsample_identity(): """Downsample to the same size should return same array (approximately).""" @@ -274,10 +271,12 @@ def test_downsample_identity(): # Since OpenCV / skimage may do minor interpolation, allow small tolerance np.testing.assert_allclose(result, arr, rtol=1e-6, atol=1e-6) + # ---------------------------- # Backend fallback tests # ---------------------------- + @mock.patch("wf_psf.utils.utils._HAS_CV2", False) @mock.patch("wf_psf.utils.utils._HAS_SKIMAGE", False) def test_downsample_no_backend(): @@ -296,10 +295,11 @@ def test_downsample_values_average(): # All output values should be close to input value np.testing.assert_allclose(result, 3.0, rtol=1e-6, atol=1e-6) + @mock.patch("wf_psf.utils.utils._HAS_CV2", True) def test_downsample_non_square_array(): """Check downsampling works for non-square arrays.""" arr = np.arange(12).reshape(3, 4).astype(np.float32) output_dim = 2 result = downsample_im(arr, output_dim) - assert result.shape == (2, 2) \ No newline at end of file + assert result.shape == (2, 2) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index f2366ca7..ab2f0ac1 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -274,10 +274,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): - """Generate factory for loss, metrics, monitor, and outputs. - - A function to generate loss, metrics, monitor, and outputs - for training. + """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. Parameters ---------- @@ -370,12 +367,12 @@ def train( psf_model_dir : str Directory where the final trained PSF model weights will be saved per cycle. - Notes - ----- - - Utilizes TensorFlow and TensorFlow Addons for model training and optimization. - - Supports masked mean squared error loss for training with masked data. - - Allows for projection of data-driven features onto parametric models between cycles. - - Supports resetting of non-parametric features to initial states. 
+ Returns + ------- + None + + Side Effects + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations From 0e85c0f198417063b1458ffc11dee9e07bd92648 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:31:46 +0100 Subject: [PATCH 110/129] Correct type hint errors --- src/wf_psf/data/data_handler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index c0ddf7b0..bdcf9a6b 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -241,7 +241,7 @@ def process_sed_data(self, sed_data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """ + """ Extract and concatenate star-related data from training and test datasets. This function retrieves arrays (e.g., postage stamps, masks, positions) from @@ -310,8 +310,8 @@ def get_data_array( key: str = None, train_key: str = None, test_key: str = None, - allow_missing: bool = False, -) -> np.ndarray | None: + allow_missing: bool = True, +) -> Optional[np.ndarray]: """ Retrieve data from dataset depending on run type. @@ -337,7 +337,7 @@ def get_data_array( test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. - allow_missing : bool, default False + allow_missing : bool, default True Control behavior when data is missing or keys are not found: - True: Return None instead of raising exceptions - False: Raise appropriate exceptions (KeyError, ValueError) @@ -401,7 +401,7 @@ def get_data_array( raise -def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: """ Extract data directly with proper error handling and type conversion. 
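With allow_missing now defaulting to True, inference-path callers can probe for optional keys without wrapping every lookup in try/except. A short usage sketch, assuming a data object that exposes the flat dataset dict used above:

    # Optional lookup: returns None instead of raising when the key is absent.
    masks = get_data_array(data, "inference", key="masks", allow_missing=True)
    if masks is None:
        print("No masks available; proceeding without masking.")

    # Strict lookup: a missing key raises KeyError, and omitting the key
    # entirely raises ValueError.
    positions = get_data_array(
        data, "inference", key="positions", allow_missing=False
    )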
From 2bdd6225f509158dfe6480852948c766b4870b06 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:16 +0100 Subject: [PATCH 111/129] Remove unused import --- src/wf_psf/instrument/ccd_misalignments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index f739c4fc..49bc8180 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,6 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: From 6134d8e7cf7c6ed629b873c7f90ba0ae198dd7d8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:56 +0100 Subject: [PATCH 112/129] Replace call to deprecated get_np_obs_positions with get_data_array --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index b2d1b53e..21c0f9a4 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_np_obs_positions +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -222,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) + self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From e309487a2963fde3c83063278ee1d5012d8229fa Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:33:40 +0100 Subject: [PATCH 113/129] Remove unused imports and reformat --- src/wf_psf/data/centroids.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 20391793..01135428 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,9 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -import tensorflow as tf from typing import Optional @@ -252,7 +250,6 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" - # Convert to np.ndarray if not already im = np.asarray(im) if mask is not None: From 88325849ecdc4f7de49476b134135b9a71a1439c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:35:13 +0100 Subject: [PATCH 114/129] Update fixtures and unit tests --- src/wf_psf/tests/test_data/centroids_test.py | 38 ++--- src/wf_psf/tests/test_data/conftest.py | 16 +- .../tests/test_data/data_handler_test.py | 148 ++++++++++++++++-- .../test_data/data_zernike_utils_test.py | 30 +++- .../test_inference/psf_inference_test.py | 1 - 5 files changed, 192 insertions(+), 41 deletions(-) diff --git 
a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 85719f4f..185da8d7 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,8 +9,6 @@ import numpy as np import pytest from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator -from wf_psf.data.data_handler import extract_star_data -from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch @@ -116,24 +114,23 @@ def test_compute_centroid_correction_with_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } + # Mock the internal function calls: with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call the function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Ensure the result has the correct shape assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) @@ -148,10 +145,6 @@ def test_compute_centroid_correction_with_masks(mock_data): def test_compute_centroid_correction_without_masks(mock_data): """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - # Define model parameters model_params = RecursiveNamespace( pix_sampling=12e-6, # Example pixel sampling in meters @@ -159,26 +152,23 @@ def test_compute_centroid_correction_without_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } + # Mock internal function calls with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - # Mock compute_zernike_tip_tilt assuming no masks mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Validate result shape assert result.shape == (4, 3) # (n_stars, 3 Zernike components) diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 47eed929..131922e5 100644 --- 
a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -11,9 +11,11 @@ import pytest import numpy as np import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -129,6 +131,18 @@ def mock_data(scope="module"): ) +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index c1310d21..d29771a1 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -3,12 +3,10 @@ import tensorflow as tf from wf_psf.data.data_handler import ( DataHandler, - get_np_obs_positions, + get_data_array, extract_star_data, ) from wf_psf.utils.read_config import RecursiveNamespace -import logging -from unittest.mock import patch def mock_sed(): @@ -151,12 +149,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.validate_and_process_dataset() -def test_get_np_obs_positions(mock_data): - observed_positions = get_np_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") @@ -229,3 +221,141 @@ def test_reference_shifts_broadcasting(): # Test the result assert displacements.shape == shifts.shape, "Shapes do not match" assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # ================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + 
"mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 390d10f9..66d23309 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -11,7 +11,6 @@ compute_zernike_tip_tilt, pad_tf_zernikes, ) -from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -46,7 +45,27 @@ def test_training_without_prior(mock_model_params, mock_data): data=mock_data, run_type="training", model_params=mock_model_params ) - assert zinputs.centroid_dataset is mock_data + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + ] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + assert zinputs.zernike_prior is None expected_positions = np.concatenate( @@ -88,7 +107,7 @@ def test_training_with_explicit_prior(mock_model_params, caplog): data, "training", mock_model_params, prior=explicit_prior ) - assert "Zernike prior explicitly provided" in caplog.text + assert "Explicit prior provided; ignoring dataset-based prior." 
in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -103,7 +122,8 @@ def test_inference_with_dict_and_prior(mock_model_params): zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) - assert zinputs.centroid_dataset is None + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None # NumPy array comparison np.testing.assert_array_equal( @@ -397,7 +417,6 @@ def test_pad_zernikes_param_greater_than_prior(): def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True @@ -456,7 +475,6 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 05728de7..28a2c1af 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -161,7 +161,6 @@ def psf_single_star_setup(mock_inference_config): def test_set_config_paths(mock_inference_config): """Test setting configuration paths.""" - # Initialize handler and inject mock config config_handler = InferenceConfigHandler("fake/path") config_handler.inference_config = mock_inference_config From 529fcd3c8dcb330d76bfd0e11a0a9d2945fa2555 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:36:32 +0100 Subject: [PATCH 115/129] Reformat with black --- src/wf_psf/data/data_handler.py | 41 ++- src/wf_psf/data/data_zernike_utils.py | 2 - src/wf_psf/data/old_zernike_prior.py | 335 ++++++++++++++++++ src/wf_psf/instrument/ccd_misalignments.py | 6 +- .../psf_model_physical_polychromatic.py | 3 - src/wf_psf/psf_models/psf_model_loader.py | 1 - src/wf_psf/psf_models/tf_modules/tf_layers.py | 1 - .../psf_models/tf_modules/tf_psf_field.py | 4 +- src/wf_psf/psf_models/tf_modules/tf_utils.py | 1 - .../test_metrics/metrics_interface_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 1 - src/wf_psf/utils/utils.py | 1 + 12 files changed, 361 insertions(+), 37 deletions(-) create mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bdcf9a6b..052fe730 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -118,7 +118,6 @@ def __init__( `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ - self.dataset_type = dataset_type self.data_params = data_params self.simPSF = simPSF @@ -182,7 +181,6 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["positions"] = ensure_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -244,8 +242,8 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ Extract and concatenate star-related data from training and test datasets. 
- This function retrieves arrays (e.g., postage stamps, masks, positions) from - both the training and test datasets using the specified keys, converts them + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them to NumPy if necessary, and concatenates them along the first axis. Parameters @@ -253,27 +251,27 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: data : DataConfigHandler Object containing training and test datasets. train_key : str - Key to retrieve data from the training dataset + Key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). test_key : str - Key to retrieve data from the test dataset + Key to retrieve data from the test dataset (e.g., 'stars', 'masks'). Returns ------- np.ndarray - Concatenated NumPy array containing the selected data from both + Concatenated NumPy array containing the selected data from both training and test sets. Raises ------ KeyError - If either the training or test dataset does not contain the + If either the training or test dataset does not contain the requested key. Notes ----- - - Designed for datasets with separate train/test splits, such as when + - Designed for datasets with separate train/test splits, such as when evaluating metrics on held-out data. - TensorFlow tensors are automatically converted to NumPy arrays. - Requires eager execution if TensorFlow tensors are present. @@ -304,6 +302,7 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + def get_data_array( data, run_type: str, @@ -334,7 +333,7 @@ def get_data_array( train_key : str, optional Key for training dataset access. If None and key is provided, defaults to key value. Default is None. - test_key : str, optional + test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. allow_missing : bool, default True @@ -355,12 +354,12 @@ def get_data_array( resolved for the operation and allow_missing=False. KeyError If the specified key is not found in the dataset and allow_missing=False. - + Notes ----- Key resolution follows this priority order: 1. train_key = train_key or key - 2. test_key = test_key or resolved_train_key + 2. test_key = test_key or resolved_train_key 3. key = key or resolved_train_key (for inference fallback) For TensorFlow tensors, the .numpy() method is called to convert to NumPy. @@ -370,13 +369,13 @@ def get_data_array( -------- >>> # Training data retrieval >>> train_data = get_data_array(data, "training", train_key="noisy_stars") - + >>> # Inference with fallback handling - >>> inference_data = get_data_array(data, "inference", key="positions", + >>> inference_data = get_data_array(data, "inference", key="positions", ... allow_missing=True) >>> if inference_data is None: ... 
print("No inference data available") - + >>> # Using key parameter for both train and inference >>> result = get_data_array(data, "inference", key="positions") """ @@ -384,18 +383,18 @@ def get_data_array( valid_run_types = {"training", "simulation", "metrics", "inference"} if run_type not in valid_run_types: raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") - + # Simplify key resolution with clear precedence effective_train_key = train_key or key effective_test_key = test_key or effective_train_key effective_key = key or effective_train_key - + try: if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference return _get_direct_data(data, effective_key, allow_missing) - except Exception as e: + except Exception: if allow_missing: return None raise @@ -417,7 +416,7 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray Returns ------- np.ndarray or None - Data converted to NumPy array, or None if allow_missing=True and + Data converted to NumPy array, or None if allow_missing=True and data is unavailable. Raises @@ -437,11 +436,11 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray if allow_missing: return None raise ValueError("No key provided for inference data") - + value = data.dataset.get(key, None) if value is None: if allow_missing: return None raise KeyError(f"Key '{key}' not found in inference dataset") - + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 399ee9ef..0fad9c8e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -181,7 +181,6 @@ def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): padded_zk_prior : tf.Tensor Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. """ - pad_num_param = n_zks_total - tf.shape(zk_param)[1] pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] @@ -230,7 +229,6 @@ def assemble_zernike_contributions( tf.Tensor A tensor representing the full Zernike contribution map. """ - zernike_contribution_list = [] # Prior diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py new file mode 100644 index 00000000..0feb3e70 --- /dev/null +++ b/src/wf_psf/data/old_zernike_prior.py @@ -0,0 +1,335 @@ +import numpy as np +import tensorflow as tf +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt +from fractions import Fraction +import logging + +logger = logging.getLogger(__name__) + + +def get_np_obs_positions(data): + """Get observed positions in numpy from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + np.ndarray + Numpy array containing the observed positions of the stars. + + Notes + ----- + The observed positions are obtained by concatenating the positions of stars + from both the training and test datasets along the 0th axis. 
+ """ + obs_positions = np.concatenate( + ( + data.training_data.dataset["positions"], + data.test_data.dataset["positions"], + ), + axis=0, + ) + + return obs_positions + + +def get_obs_positions(data): + """Get observed positions from the provided dataset. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + """ + obs_positions = get_np_obs_positions(data) + + return tf.convert_to_tensor(obs_positions, dtype=tf.float32) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """Extract specific star-related data from training and test datasets. + + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the + star training and test datasets such as star stamps or masks, based on the provided keys. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + test_key : str + The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + A NumPy array containing the concatenated data for the given keys. + + Raises + ------ + KeyError + If the specified keys do not exist in the training or test datasets. + + Notes + ----- + - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. + - Ensure that eager execution is enabled when calling this function. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. + + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + data : DataConfigHandler + An object containing training and test datasets, including observed PSFs + and optional star masks. 
+ + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + star_postage_stamps = extract_star_data( + data=data, train_key="noisy_stars", test_key="stars" + ) + + # Get star mask catalogue only if "masks" exist in both training and test datasets + star_masks = ( + extract_star_data(data=data, train_key="masks", test_key="masks") + if ( + data.training_data.dataset.get("masks") is not None + and data.test_data.dataset.get("masks") is not None + and tf.size(data.training_data.dataset["masks"]) > 0 + and tf.size(data.test_data.dataset["masks"]) > 0 + ) + else None + ) + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + # Ensure star_masks is properly handled + star_masks = star_masks if star_masks is not None else None + + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + + +def compute_ccd_misalignment(model_params, data): + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + +def get_zernike_prior(model_params, data, batch_size: int = 16): + """Get Zernike priors from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. 
+ batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + + Notes + ----- + The Zernike prior are obtained by concatenating the Zernike priors + from both the training and test datasets along the 0th axis. + + """ + # List of zernike contribution + zernike_contribution_list = [] + + if model_params.use_prior: + logger.info("Reading in Zernike prior into Zernike contribution list...") + zernike_contribution_list.append(get_np_zernike_prior(data)) + + if model_params.correct_centroids: + logger.info("Adding centroid correction to Zernike contribution list...") + zernike_contribution_list.append( + compute_centroid_correction(model_params, data, batch_size) + ) + + if model_params.add_ccd_misalignments: + logger.info("Adding CCD mis-alignments to Zernike contribution list...") + zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) + + if len(zernike_contribution_list) == 1: + zernike_contribution = zernike_contribution_list[0] + else: + # Get max zk order + max_zk_order = np.max( + np.array( + [ + zk_contribution.shape[1] + for zk_contribution in zernike_contribution_list + ] + ) + ) + + zernike_contribution = np.zeros( + (zernike_contribution_list[0].shape[0], max_zk_order) + ) + + # Pad arrays to get the same length and add the final contribution + for it in range(len(zernike_contribution_list)): + current_zk_order = zernike_contribution_list[it].shape[1] + current_zernike_contribution = np.pad( + zernike_contribution_list[it], + pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], + mode="constant", + constant_values=0, + ) + + zernike_contribution += current_zernike_contribution + + return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 49bc8180..873509e5 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -161,11 +161,7 @@ def _preprocess_tile_data(self) -> None: self.tiles_z_average = np.mean(self.tiles_z_lims) def _initialize_polygons(self): - """Initialize polygons to look for CCD IDs. - - Each CCD is represented by a polygon defined by its corner points. 
- - """ + """Initialize polygons to look for CCD IDs""" # Build polygon list corresponding to each CCD self.ccd_polygons = [] diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index e2c84868..fb8bc902 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -26,8 +26,6 @@ TFPhysicalLayer, ) from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler import logging @@ -282,7 +280,6 @@ def tf_zernike_OPD(self): def _build_tf_batch_poly_PSF(self): """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" - return TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index e445f7af..e41e3536 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -10,7 +10,6 @@ import logging from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath -import tensorflow as tf logger = logging.getLogger(__name__) diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index 7b22b5e8..cf11d9cd 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -999,7 +999,6 @@ def call(self, positions): If the shape of the input `positions` tensor is not compatible. """ - # Find indices for all positions in one batch operation idx = find_position_indices(self.obs_pos, positions) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 21c0f9a4..07b523d1 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -222,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32) + self.obs_pos = ensure_tensor( + get_data_array(data, data.run_type, key="positions"), dtype=tf.float32 + ) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index 09540e60..4bd1246a 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -16,7 +16,6 @@ """ import tensorflow as tf -import numpy as np @tf.function diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index bf52f0aa..5672d648 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,6 +1,6 @@ from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler +from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.data.data_handler import DataHandler diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index f8299997..57dfdc8a 100644 
--- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -13,7 +13,6 @@ from wf_psf.utils.io import FileIOHandler from wf_psf.utils.configs_handler import ( TrainingConfigHandler, - MetricsConfigHandler, DataConfigHandler, ) import os diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 17219ad9..2027d01c 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -5,6 +5,7 @@ """ import numpy as np +from typing import Tuple import tensorflow as tf import PIL import zernike as zk From 6c3e9d099e6256edfc2ff1429691c988548191d9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 8 Dec 2025 11:01:46 +0100 Subject: [PATCH 116/129] Remove outdated back module --- src/wf_psf/data/old_zernike_prior.py | 335 --------------------------- 1 file changed, 335 deletions(-) delete mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py deleted file mode 100644 index 0feb3e70..00000000 --- a/src/wf_psf/data/old_zernike_prior.py +++ /dev/null @@ -1,335 +0,0 @@ -import numpy as np -import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. 
- - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. 
- """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. - """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. 
- - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) From af12a6cf905de372531a084f8cb05b928bbca6d9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 8 Dec 2025 13:21:46 +0100 Subject: [PATCH 117/129] Remove unused import left after rebase --- src/wf_psf/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 2027d01c..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -5,7 +5,6 @@ """ import numpy as np -from typing import Tuple import tensorflow as tf import PIL import zernike as zk From 07e88c7004c0fc3490649e57874d41b19af9f4ac Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:17:20 +0100 Subject: [PATCH 118/129] Update doc strings and cache handling - Updated classes and methods with complete doc strings - Added two cache clearing methods to PSFInference and PSFInferenceEngine classes --- src/wf_psf/inference/psf_inference.py | 352 +++++++++++++++++++++++++- 1 file changed, 341 insertions(+), 11 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index c7d73249..dc7b5ac7 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -20,6 +20,37 @@ class InferenceConfigHandler: + """ + Handle configuration loading and management for PSF inference. + + This class manages the loading of inference, training, and data configuration + files required for PSF inference operations. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + inference_config : RecursiveNamespace or None + Loaded inference configuration. + training_config : RecursiveNamespace or None + Loaded training configuration. + data_config : RecursiveNamespace or None + Loaded data configuration. + trained_model_path : Path + Path to the trained model directory. + model_subdir : str + Subdirectory name for model files. 
+ trained_model_config_path : Path + Path to the training configuration file. + data_config_path : str or None + Path to the data configuration file. + """ + ids = ("inference_conf",) def __init__(self, inference_config_path: str): @@ -29,7 +60,20 @@ def __init__(self, inference_config_path: str): self.data_config = None def load_configs(self): - """Load configuration files based on the inference config.""" + """ + Load configuration files based on the inference config. + + Loads the inference configuration first, then uses it to determine and load + the training and data configurations. + + Notes + ----- + Updates the following attributes in-place: + - inference_config + - training_config + - data_config (if data_config_path is specified) + """ + self.inference_config = read_conf(self.inference_config_path) self.set_config_paths() self.training_config = read_conf(self.trained_model_config_path) @@ -39,7 +83,15 @@ def load_configs(self): self.data_config = read_conf(self.data_config_path) def set_config_paths(self): - """Extract and set the configuration paths.""" + """ + Extract and set the configuration paths from the inference config. + + Sets the following attributes: + - trained_model_path + - model_subdir + - trained_model_config_path + - data_config_path + """ # Set config paths config_paths = self.inference_config.inference.configs @@ -97,6 +149,35 @@ class PSFInference: Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). masks : array-like, optional Corresponding masks for the sources (same shape as sources). Defaults to None. + + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + x_field : array-like or None + x coordinates for PSF positions. + y_field : array-like or None + y coordinates for PSF positions. + seds : array-like or None + Spectral energy distributions. + sources : array-like or None + Source postage stamps. + masks : array-like or None + Source masks. + engine : PSFInferenceEngine or None + The inference engine instance. + + Examples + -------- + >>> psf_inf = PSFInference( + ... inference_config_path="config.yaml", + ... x_field=[100.5, 200.3], + ... y_field=[150.2, 250.8], + ... seds=sed_array + ... ) + >>> psf_inf.run_inference() + >>> psf = psf_inf.get_psf(0) """ def __init__( @@ -133,13 +214,25 @@ def __init__( @property def config_handler(self): + """ + Get or create the configuration handler. + + Returns + ------- + InferenceConfigHandler + The configuration handler instance with loaded configs. + """ if self._config_handler is None: self._config_handler = InferenceConfigHandler(self.inference_config_path) self._config_handler.load_configs() return self._config_handler def prepare_configs(self): - """Prepare the configuration for inference.""" + """ + Prepare the configuration for inference. + + Overwrites training model parameters with inference configuration values. + """ # Overwrite model parameters with inference config self.config_handler.overwrite_model_params( self.training_config, self.inference_config @@ -147,24 +240,63 @@ def prepare_configs(self): @property def inference_config(self): + """ + Get the inference configuration. + + Returns + ------- + RecursiveNamespace + The inference configuration object. + """ return self.config_handler.inference_config @property def training_config(self): + """ + Get the training configuration. + + Returns + ------- + RecursiveNamespace + The training configuration object. 
+ """ return self.config_handler.training_config @property def data_config(self): + """ + Get the data configuration. + + Returns + ------- + RecursiveNamespace or None + The data configuration object, or None if not available. + """ return self.config_handler.data_config @property def simPSF(self): + """ + Get or create the PSF simulator. + + Returns + ------- + simPSF + The PSF simulator instance. + """ if self._simPSF is None: self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF def _prepare_dataset_for_inference(self): - """Prepare dataset dictionary for inference, returning None if positions are invalid.""" + """ + Prepare dataset dictionary for inference. + + Returns + ------- + dict or None + Dictionary containing positions, sources, and masks, or None if positions are invalid. + """ positions = self.get_positions() if positions is None: return None @@ -172,6 +304,14 @@ def _prepare_dataset_for_inference(self): @property def data_handler(self): + """ + Get or create the data handler. + + Returns + ------- + DataHandler + The data handler instance configured for inference. + """ if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( @@ -188,6 +328,14 @@ def data_handler(self): @property def trained_psf_model(self): + """ + Get or load the trained PSF model. + + Returns + ------- + Model + The loaded trained PSF model. + """ if self._trained_psf_model is None: self._trained_psf_model = self.load_inference_model() return self._trained_psf_model @@ -229,7 +377,19 @@ def get_positions(self): return np.column_stack((x_flat, y_flat)) def load_inference_model(self): - """Load the trained PSF model based on the inference configuration.""" + """Load the trained PSF model based on the inference configuration. + + Returns + ------- + Model + The loaded trained PSF model. + + Notes + ----- + Constructs the weights path pattern based on the trained model path, + model subdirectory, model name, id name, and cycle number specified in the + configuration files. + """ model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -250,6 +410,12 @@ def load_inference_model(self): @property def n_bins_lambda(self): + """Get the number of wavelength bins for inference. + + Returns + ------- + int + The number of wavelength bins used during inference.""" if self._n_bins_lambda is None: self._n_bins_lambda = ( self.inference_config.inference.model_params.n_bins_lda @@ -258,6 +424,14 @@ def n_bins_lambda(self): @property def batch_size(self): + """ + Get the batch size for inference. + + Returns + ------- + int + The batch size for processing during inference. + """ if self._batch_size is None: self._batch_size = self.inference_config.inference.batch_size assert self._batch_size > 0, "Batch size must be greater than 0." @@ -265,12 +439,26 @@ def batch_size(self): @property def cycle(self): + """Get the cycle number for inference. + + Returns + ------- + int + The cycle number used for loading the trained model. + """ if self._cycle is None: self._cycle = self.inference_config.inference.cycle return self._cycle @property def output_dim(self): + """Get the output dimension for PSF inference. + + Returns + ------- + int + The output dimension (height and width) of the inferred PSFs. 
+ """ if self._output_dim is None: self._output_dim = self.inference_config.inference.model_params.output_dim return self._output_dim @@ -317,7 +505,18 @@ def _prepare_positions_and_seds(self): return positions, sed_data_tensor def run_inference(self): - """Run PSF inference and return the full PSF array.""" + """Run PSF inference and return the full PSF array. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Prepares configurations and input data, initializes the inference engine, + and computes the PSF for all input positions. + """ # Prepare the configuration for inference self.prepare_configs() @@ -332,10 +531,25 @@ def run_inference(self): return self.engine.compute_psfs(positions, sed_data) def _ensure_psf_inference_completed(self): + """Ensure that PSF inference has been completed. + + Runs inference if it has not been done yet. + """ if self.engine is None or self.engine.inferred_psfs is None: self.run_inference() def get_psfs(self): + """Get all inferred PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSFs. + """ self._ensure_psf_inference_completed() return self.engine.get_psfs() @@ -343,7 +557,21 @@ def get_psf(self, index: int = 0) -> np.ndarray: """ Get the PSF at a specific index. - If only a single star was passed during instantiation, the index defaults to 0. + Parameters + ---------- + index : int, optional + Index of the PSF to retrieve (default is 0). + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSF. + If only a single star was passed during instantiation, the index defaults to 0 + and bounds checking is relaxed. """ self._ensure_psf_inference_completed() @@ -356,8 +584,60 @@ def get_psf(self, index: int = 0) -> np.ndarray: # Otherwise, return the PSF at the requested index return inferred_psfs[index] + def clear_cache(self): + """ + Clear all cached properties and reset the instance. + + This method resets all lazy-loaded properties, including the config handler, + PSF simulator, data handler, trained model, and inference engine. Useful for + freeing memory or forcing a fresh initialization. + + Notes + ----- + After calling this method, accessing any property will trigger re-initialization. + """ + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + self.engine = None + class PSFInferenceEngine: + """Engine to perform PSF inference using a trained model. + + This class handles the batch-wise computation of PSFs using a trained PSF model. + It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access. + + Parameters + ---------- + trained_model : Model + The trained PSF model to use for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Attributes + ---------- + trained_model : Model + The trained PSF model used for inference. + batch_size : int + The batch size for processing during inference. 
+ output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Examples + -------- + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) + """ + def __init__(self, trained_model, batch_size: int, output_dim: int): self.trained_model = trained_model self.batch_size = batch_size @@ -366,11 +646,35 @@ def __init__(self, trained_model, batch_size: int, output_dim: int): @property def inferred_psfs(self) -> np.ndarray: - """Access the cached inferred PSFs, if available.""" + """Access the cached inferred PSFs, if available. + + Returns + ------- + numpy.ndarray or None + The cached inferred PSFs, or None if not yet computed. + """ return self._inferred_psfs def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: - """Compute and cache PSFs for the input source parameters.""" + """Compute and cache PSFs for the input source parameters. + + Parameters + ---------- + positions : tf.Tensor + Tensor of shape (n_samples, 2) containing the (x, y) positions + sed_data : tf.Tensor + Tensor of shape (n_samples, n_bins, 2) containing the SEDs + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + PSFs are computed in batches according to the specified batch_size. + Results are cached internally for subsequent access via get_psfs() or get_psf(). + """ n_samples = positions.shape[0] self._inferred_psfs = np.zeros( (n_samples, self.output_dim, self.output_dim), dtype=np.float32 @@ -397,13 +701,39 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: return self._inferred_psfs def get_psfs(self) -> np.ndarray: - """Get all the generated PSFs.""" + """Get all the generated PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + """ if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs def get_psf(self, index: int) -> np.ndarray: - """Get the PSF at a specific index.""" + """Get the PSF at a specific index. + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Raises + ------ + ValueError + If PSFs have not yet been computed. + """ if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs[index] + + def clear_cache(self): + """ + Clear cached inferred PSFs. + + Resets the internal PSF cache to free memory. After calling this method, + compute_psfs() must be called again before accessing PSFs. 
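+
+        Examples
+        --------
+        A minimal sketch (illustrative only; assumes PSFs were previously
+        computed from ``positions`` and ``sed_data`` tensors):
+
+        .. code-block:: python
+
+            engine.clear_cache()               # drop the cached PSF array
+            assert engine.inferred_psfs is None
+            psfs = engine.compute_psfs(positions, sed_data)  # recompute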
+ """ + self._inferred_psfs = None From 43f894ffe4404ac6840dcaf17128538abda101fe Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:20:14 +0100 Subject: [PATCH 119/129] Update psf_test_setup fixture for reusability and add unit tests for cache handling --- .../test_inference/psf_inference_test.py | 111 ++++++++++++++++-- 1 file changed, 98 insertions(+), 13 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 28a2c1af..4cff7a13 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -159,6 +159,44 @@ def psf_single_star_setup(mock_inference_config): } +@pytest.fixture +def mock_compute_psfs_with_cache(psf_test_setup): + """ + Fixture that patches PSFInferenceEngine.compute_psfs with a side effect + that populates the engine's cache. + + Returns + ------- + dict + Dictionary containing: + - 'mock': The mock object for compute_psfs + - 'inference': The PSFInference instance + - 'positions': Mock positions tensor + - 'seds': Mock SEDs tensor + - 'expected_psfs': Expected PSF array + """ + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + with patch.object(PSFInferenceEngine, "compute_psfs") as mock_compute_psfs: + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + yield { + "mock": mock_compute_psfs, + "inference": inference, + "positions": mock_positions, + "seds": mock_seds, + "expected_psfs": expected_psfs, + } + + def test_set_config_paths(mock_inference_config): """Test setting configuration paths.""" # Initialize handler and inject mock config @@ -365,30 +403,25 @@ def test_simpsf_uses_updated_model_params( @patch.object(PSFInference, "_prepare_positions_and_seds") -@patch.object(PSFInferenceEngine, "compute_psfs") def test_get_psfs_runs_inference( - mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup + mock_prepare_positions_and_seds, mock_compute_psfs_with_cache ): - inference = psf_test_setup["inference"] - mock_positions = psf_test_setup["mock_positions"] - mock_seds = psf_test_setup["mock_seds"] - expected_psfs = psf_test_setup["expected_psfs"] + """Test that get_psfs uses cached PSFs after first computation.""" + mock = mock_compute_psfs_with_cache["mock"] + inference = mock_compute_psfs_with_cache["inference"] + mock_positions = mock_compute_psfs_with_cache["positions"] + mock_seds = mock_compute_psfs_with_cache["seds"] + expected_psfs = mock_compute_psfs_with_cache["expected_psfs"] mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) - def fake_compute_psfs(positions, seds): - inference.engine._inferred_psfs = expected_psfs - return expected_psfs - - mock_compute_psfs.side_effect = fake_compute_psfs - psfs_1 = inference.get_psfs() assert np.all(psfs_1 == expected_psfs) psfs_2 = inference.get_psfs() assert np.all(psfs_2 == expected_psfs) - assert mock_compute_psfs.call_count == 1 + assert mock.call_count == 1 def test_single_star_inference_shape(psf_single_star_setup): @@ -483,3 +516,55 @@ def test_valueerror_on_mismatched_positions(psf_single_star_setup): inference._prepare_positions_and_seds() finally: patcher.stop() + + +def test_inference_clear_cache(psf_test_setup): + """Test that PSFInference 
clear_cache resets all cached attributes of the instance."""
+    inference = psf_test_setup["inference"]
+    inference._simPSF = MagicMock()
+    inference._data_handler = MagicMock()
+    inference._trained_psf_model = MagicMock()
+    inference._n_bins_lambda = MagicMock()
+    inference._batch_size = MagicMock()
+    inference._cycle = MagicMock()
+    inference._output_dim = MagicMock()
+    inference.engine = MagicMock()
+
+    # Clear the cache
+    inference.clear_cache()
+
+    # Check that every lazy-loaded attribute was reset
+    for attr in (
+        inference._config_handler,
+        inference._simPSF,
+        inference._data_handler,
+        inference._trained_psf_model,
+        inference._n_bins_lambda,
+        inference._batch_size,
+        inference._cycle,
+        inference._output_dim,
+        inference.engine,
+    ):
+        assert attr is None, "Inference attributes should be cleared to None"
+
+
+def test_engine_clear_cache(psf_test_setup):
+    """Test that clear_cache resets the internal PSF cache."""
+    inference = psf_test_setup["inference"]
+    expected_psfs = psf_test_setup["expected_psfs"]
+
+    # Create the engine and seed its PSF cache directly
+    inference.engine = PSFInferenceEngine(
+        trained_model=inference.trained_psf_model,
+        batch_size=inference.batch_size,
+        output_dim=inference.output_dim,
+    )
+
+    inference.engine._inferred_psfs = expected_psfs
+
+    # Clear the cache
+    inference.engine.clear_cache()
+
+    # Check that the internal cache is None
+    assert (
+        inference.engine._inferred_psfs is None
+    ), "PSF cache should be cleared to None"

From a88c455357a540ddc0ce07271e70995bea387c85 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 12 Dec 2025 13:21:50 +0100
Subject: [PATCH 120/129] Remove unneeded mock of logger in evaluate_model

---
 src/wf_psf/tests/test_metrics/metrics_interface_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py
index 5672d648..2d6b6553 100644
--- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py
+++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py
@@ -200,8 +200,6 @@ def test_evaluate_model(
         ) as mock_evaluate_shape_results_dict,
         patch("numpy.save", new_callable=MagicMock) as mock_np_save,
     ):
-        # Mock the logger
-        _ = mocker.patch("wf_psf.metrics.metrics_interface.logger")
 
         # Call evaluate_model
         evaluate_model(

From 6f0e21dcba30a2f3fdff93134e89cf3689a94c76 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 12 Dec 2025 13:22:16 +0100
Subject: [PATCH 121/129] Add wf_psf.inference to list of main packages

---
 docs/source/api.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/api.rst b/docs/source/api.rst
index 3bc4d395..cda87632 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -8,6 +8,7 @@ This section contains the API reference for the main packages in WaveDiff.
:recursive: wf_psf.data + wf_psf.inference wf_psf.metrics wf_psf.plotting wf_psf.psf_models From cd715c30a78ef5049df2e9fa16c3eef86e0a68e4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:30:18 +0100 Subject: [PATCH 122/129] Update version to 3.1.0 --- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 988df8fe..01edfeb6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ else: copyright = f"{start_year}, CosmoStat" author = "CosmoStat" -release = "3.0.0" +release = "3.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 3107f12c..ec0c8d34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "seaborn", ] -version = "3.0.0" +version = "3.1.0" [project.optional-dependencies] docs = [ From 5d54518e99e16586fba8f2323901b830a19b0915 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:34:11 +0100 Subject: [PATCH 123/129] docs: Correct syntax error and code-block formatting in doc string examples --- src/wf_psf/inference/psf_inference.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index dc7b5ac7..70d7a9be 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -151,7 +151,7 @@ class PSFInference: Corresponding masks for the sources (same shape as sources). Defaults to None. - Attributes + Attributes ---------- inference_config_path : str Path to the inference configuration file. @@ -170,14 +170,18 @@ class PSFInference: Examples -------- - >>> psf_inf = PSFInference( - ... inference_config_path="config.yaml", - ... x_field=[100.5, 200.3], - ... y_field=[150.2, 250.8], - ... seds=sed_array - ... ) - >>> psf_inf.run_inference() - >>> psf = psf_inf.get_psf(0) + Basic usage with position coordinates and SEDs: + + .. code-block:: python + + psf_inf = PSFInference( + inference_config_path="config.yaml", + x_field=[100.5, 200.3], + y_field=[150.2, 250.8], + seds=sed_array + ) + psf_inf.run_inference() + psf = psf_inf.get_psf(0) """ def __init__( @@ -633,9 +637,11 @@ class PSFInferenceEngine: Examples -------- - >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) - >>> psfs = engine.compute_psfs(positions, seds) - >>> single_psf = engine.get_psf(0) + .. code-block:: python + + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) """ def __init__(self, trained_model, batch_size: int, output_dim: int): From b5da83f8d781a64d61cf5898de8d902b8a7283a8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 Jan 2026 13:58:22 +0100 Subject: [PATCH 124/129] Remove weights_path references missed during rebase Rebase dropped commits that removed weights_path from the metrics interface. Cleaning up remaining references in docstrings and tests. 
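For reference, a minimal sketch of the updated call as exercised in the
tests (mock values stand in for real objects; any remaining arguments are
elided here):

    evaluate_model(
        trained_model_params=trained_model_params,
        data=data,
        psf_model=psf_model,
        metrics_output="/path/to/metrics/output",
    )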
--- src/wf_psf/metrics/metrics_interface.py | 2 -- src/wf_psf/tests/test_metrics/metrics_interface_test.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index bd6249e2..db410f4e 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -328,8 +328,6 @@ def evaluate_model( DataHandler object containing training and test data psf_model: object PSF model object - weights_path: str - Directory location of model weights metrics_output: str Directory location of metrics output diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 2d6b6553..35f3b9a1 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -106,7 +106,6 @@ def test_evaluate_model_flags( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -134,7 +133,6 @@ def test_missing_ground_truth_model_raises( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) @@ -168,7 +166,6 @@ def test_plotting_config_passed( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) From 98b66e94ea234153c420ac9db056d7f7ad49db9f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 Jan 2026 16:22:37 +0100 Subject: [PATCH 125/129] Add changelog fragment --- ...llack_159_psf_output_from_trained_model.md | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..1206fbcd --- /dev/null +++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,47 @@ + + + + +### New features + +- Added PSF inference capabilities for generating broadband (polychromatic) PSFs from trained models given star positions and SEDs +- Introduced `PSFInferenceEngine` class to centralize training, simulation, metrics, and inference workflows +- Added `run_type` attribute to `DataHandler` supporting training, simulation, metrics, and inference modes +- Implemented `ZernikeInputsFactory` class for building `ZernikeInputs` instances based on run type +- Added `psf_model_loader.py` module for centralized model weights loading + + + + + +### Internal changes + +- Refactored `TFPhysicalPolychromatic` and related modules to separate training vs. 
inference behavior +- Enhanced `ZernikeInputs` data class with intelligent assembly based on run type and available data +- Implemented hybrid loading pattern with eager loading in constructors and lazy-loading via property decorators +- Centralized PSF data extraction in `data_handler` module +- Improved code organization with new `tf_utils.py` module in `psf_models` sub-package +- Updated configuration handling to support inference workflows via `inference_config.yaml` +- Fixed incorrect argument name in `DataHandler` that prevented proper TensorFlow data type conversion +- Removed deprecated `get_obs_positions` method +- Updated documentation to include inference package From 33ad4f5ff64b51527fd5ba12db8b35746879a179 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 26 Jan 2026 12:57:35 +0100 Subject: [PATCH 126/129] Fix logger formatting for relative RMSE metrics Use f-strings instead of %-formatting to properly display percent symbols in metric output. --- src/wf_psf/metrics/metrics.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py index 0447d596..942a0622 100644 --- a/src/wf_psf/metrics/metrics.py +++ b/src/wf_psf/metrics/metrics.py @@ -152,8 +152,7 @@ def compute_poly_metric( # Print RMSE values logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, std_rmse) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, std_rel_rmse) - + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {std_rel_rmse:.4e}%") return rmse, rel_rmse, std_rmse, std_rel_rmse @@ -364,9 +363,8 @@ def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_si rel_rmse_std = np.std(rel_rmse_vals) # Print RMSE values - logger.info("Absolute RMSE:\t %.4e % \t +/- %.4e %", rmse, rmse_std) - logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, rel_rmse_std) - + logger.info("Absolute RMSE:\t %.4e \t +/- %.4e" % (rmse, rmse_std)) + logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {rel_rmse_std:.4e}%") return rmse, rel_rmse, rmse_std, rel_rmse_std @@ -596,10 +594,10 @@ def compute_shape_metrics( # Print relative shape/size errors logger.info( - f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e} % \t +/- {std_rel_rmse_e1:.4e} %" + f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e}% \t +/- {std_rel_rmse_e1:.4e}%" ) logger.info( - f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e} % \t +/- {std_rel_rmse_e2:.4e} %" + f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e}% \t +/- {std_rel_rmse_e2:.4e}%" ) # Print number of stars From 63e6acca71b16f9836cda3ad51757ffa647b0cf7 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 26 Jan 2026 13:05:41 +0100 Subject: [PATCH 127/129] Update changelog with entry under Bug Fixes --- ...0331_jennifer.pollack_159_psf_output_from_trained_model.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md index 1206fbcd..453a504e 100644 --- a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md +++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md @@ -21,12 +21,10 @@ For top level release notes, leave all the headers commented out. 
-
-
+
+
+
+### Internal changes
+
+- Removed the deprecated, optional `tensorflow-addons` import from `tf_layers.py`
+
+

diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py
index cf11d9cd..cdd01e16 100644
--- a/src/wf_psf/psf_models/tf_modules/tf_layers.py
+++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py
@@ -7,7 +7,6 @@
 """

 import tensorflow as tf
-import tensorflow_addons as tfa
 from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF
 from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices
 from wf_psf.utils.utils import calc_poly_position_mat
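Taken together, the inference patches in this series expose the following
user-facing workflow. A minimal sketch, assuming a hypothetical
`inference_config.yaml`, placeholder input files, a SED array shaped
`(n_stars, n_bins, 2)` as in `compute_psfs`, and a module path inferred from
the file layout:

    import numpy as np
    from wf_psf.inference.psf_inference import PSFInference

    # Hypothetical inputs: two field positions and their SEDs
    x_field = np.array([100.5, 200.3])
    y_field = np.array([150.2, 250.8])
    seds = np.load("seds.npy")  # placeholder file; shape (n_stars, n_bins, 2)

    psf_inf = PSFInference(
        inference_config_path="inference_config.yaml",  # placeholder path
        x_field=x_field,
        y_field=y_field,
        seds=seds,
    )

    psfs = psf_inf.get_psfs()   # inference runs lazily on first access
    psf0 = psf_inf.get_psf(0)   # single PSF, shape (output_dim, output_dim)
    psf_inf.clear_cache()       # reset all lazy state for a fresh run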